src.evaluation_metric.clustering_accuracy

src.evaluation_metric.clustering_accuracy(pred_clusters: torch.Tensor, gt_classes: torch.Tensor) → torch.Tensor

Compute the clustering accuracy of predicted clusters with respect to the ground-truth classes.

Let \(\mathcal{A}, \mathcal{B}, ...\) be the ground-truth classes, \(\mathcal{A}^\prime, \mathcal{B}^\prime, ...\) the predicted clusters, and \(k\) the number of classes. Then:

\[\text{clustering accuracy} = 1 - \frac{1}{k} \left(\sum_{\mathcal{A}} \sum_{\mathcal{A}^\prime \neq \mathcal{B}^\prime} \frac{|\mathcal{A}^\prime \cap \mathcal{A}| |\mathcal{B}^\prime \cap \mathcal{A}|}{|\mathcal{A}| |\mathcal{A}|} + \sum_{\mathcal{A}^\prime} \sum_{\mathcal{A} \neq \mathcal{B}} \frac{|\mathcal{A}^\prime \cap \mathcal{A}| |\mathcal{A}^\prime \cap \mathcal{B}|}{|\mathcal{A}| |\mathcal{B}|} \right)\]
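The first sum penalizes a ground-truth class that is split across two different predicted clusters; the second penalizes a predicted cluster that mixes two different classes, so a perfect clustering scores 1. Below is a minimal pure-Python sketch of the formula for a single sample (one row of the batch). It is not the library's implementation: the name `clustering_accuracy_single` is hypothetical, and it assumes that the sums over \(\mathcal{A}^\prime \neq \mathcal{B}^\prime\) and \(\mathcal{A} \neq \mathcal{B}\) run over ordered pairs and that \(k\) counts the ground-truth classes.

```python
from collections import Counter
from itertools import permutations
from typing import Hashable, Sequence


def clustering_accuracy_single(pred: Sequence[Hashable], gt: Sequence[Hashable]) -> float:
    """Hypothetical helper: clustering accuracy of ONE sample, following the formula.

    pred[i] is the predicted cluster of instance i, gt[i] its ground-truth class.
    Assumption: the "!=" sums range over ordered pairs of distinct elements.
    """
    assert len(pred) == len(gt)
    clusters = set(pred)           # predicted clusters A', B', ...
    size = Counter(gt)             # |A| for every ground-truth class A
    k = len(size)                  # assumption: k = number of ground-truth classes
    overlap = Counter(zip(gt, pred))  # overlap[(A, A')] = |A' ∩ A| (missing keys count as 0)

    # First term: class A spread over two different predicted clusters A' != B'
    split = sum(
        overlap[(a, p)] * overlap[(a, q)] / (size[a] * size[a])
        for a in size
        for p, q in permutations(clusters, 2)
    )
    # Second term: predicted cluster A' containing two different classes A != B
    mixed = sum(
        overlap[(a, p)] * overlap[(b, p)] / (size[a] * size[b])
        for p in clusters
        for a, b in permutations(size, 2)
    )
    return 1.0 - (split + mixed) / k
```

With the second example row from the parameter list below (`[0,1,2,2,1,0]` against `['bus','bus','cat','sofa','cat','sofa']`), every class is split in half and every cluster mixes two classes, so the sketch returns 0.0; a clustering that matches the classes exactly returns 1.0.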

This metric was proposed in “Wang et al. Clustering-aware Multiple Graph Matching via Decayed Pairwise Matching Composition. AAAI 2020.”

Parameters
  • pred_clusters

    \((b\times n)\) predicted cluster index of each instance. \(b\): batch size, \(n\): number of instances.

    e.g. [[0,0,1,2,1,2],
          [0,1,2,2,1,0]]
    

  • gt_classes

    \((b\times n)\) ground-truth class label of each instance

    e.g. [['car','car','bike','bike','person','person'],
          ['bus','bus','cat','sofa','cat','sofa']]
    
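Because the accuracy is computed separately for each of the \(b\) samples, labels only need to be consistent within one row: cluster 0 in the first sample is unrelated to cluster 0 in the second, and the two samples in the example use different class names. A `torch.Tensor` cannot hold strings, so if string labels have to be turned into integer ids, one option is to encode each row independently. The sketch below assumes that approach; `encode_labels_per_sample` is a hypothetical helper, not part of this module, and whether `clustering_accuracy` needs such encoding is not stated here.

```python
from typing import Dict, Hashable, List, Sequence


def encode_labels_per_sample(gt_classes: Sequence[Sequence[Hashable]]) -> List[List[int]]:
    """Hypothetical helper: map each row's labels to 0..m-1 in order of first appearance.

    Each row gets its own mapping, because labels are only compared within a sample.
    """
    encoded = []
    for row in gt_classes:
        index: Dict[Hashable, int] = {}  # label -> integer id, local to this row
        # setdefault keeps an existing id, or assigns the next free one
        encoded.append([index.setdefault(label, len(index)) for label in row])
    return encoded
```

Applied to the example above, the two rows become `[0,0,1,1,2,2]` and `[0,0,1,2,1,2]`, which could then be stacked into a \((b\times n)\) integer tensor.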

Returns

\((b)\) clustering accuracy of each sample in the batch