src.evaluation_metric.matching_recall_varied

src.evaluation_metric.matching_recall_varied(pmat_pred: torch.Tensor, pmat_gt: torch.Tensor, ns: torch.Tensor) → torch.Tensor

Matching Recall between predicted permutation matrix and ground truth permutation matrix.

\[\text{matching recall} = \frac{tr(\mathbf{X}\cdot {\mathbf{X}^{gt}}^\top)}{\sum \mathbf{X}^{gt}}\]
Parameters
  • pmat_pred: \((b\times n_1 \times n_2)\) predicted permutation matrix \((\mathbf{X})\)

  • pmat_gt: \((b\times n_1 \times n_2)\) ground truth permutation matrix \((\mathbf{X}^{gt})\)

  • ns: \((b\times 2)\) number of nodes in each pair. Batched instances may have different numbers of nodes, so ns is required to specify the exact number of nodes of each instance in the batch.

Returns

\((b)\) matching recall of each instance in the batch
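
Example

The formula can be illustrated with a minimal sketch. This is not the library's implementation: the helper name `recall_sketch` is hypothetical, and the sketch assumes that row `i` of `ns` holds \((n_1, n_2)\) for instance `i` and that padded entries lie outside the top-left \(n_1 \times n_2\) block. It uses the identity \(tr(\mathbf{X}\cdot {\mathbf{X}^{gt}}^\top) = \sum_{ij} \mathbf{X}_{ij}\mathbf{X}^{gt}_{ij}\).

```python
import torch


def recall_sketch(pmat_pred: torch.Tensor, pmat_gt: torch.Tensor, ns: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation of the recall formula, for illustration only."""
    batch = pmat_pred.shape[0]
    recall = torch.zeros(batch)
    for i in range(batch):
        # Assumption: ns[i] = (n1, n2), the true node counts of instance i.
        n1, n2 = int(ns[i, 0]), int(ns[i, 1])
        x = pmat_pred[i, :n1, :n2]
        x_gt = pmat_gt[i, :n1, :n2]
        # tr(X @ X_gt^T) equals the elementwise sum of X * X_gt.
        recall[i] = (x * x_gt).sum() / x_gt.sum()
    return recall


# Two instances padded to 3 x 3: the first has 3 nodes, the second only 2.
pmat_gt = torch.zeros(2, 3, 3)
pmat_gt[0] = torch.eye(3)
pmat_gt[1, :2, :2] = torch.eye(2)

pmat_pred = torch.zeros(2, 3, 3)
# Instance 0: only node 0 is matched correctly; nodes 1 and 2 are swapped.
pmat_pred[0] = torch.tensor([[1.0, 0.0, 0.0],
                             [0.0, 0.0, 1.0],
                             [0.0, 1.0, 0.0]])
# Instance 1: both real nodes are matched correctly.
pmat_pred[1, :2, :2] = torch.eye(2)

ns = torch.tensor([[3, 3], [2, 2]])
result = recall_sketch(pmat_pred, pmat_gt, ns)
```

Here the first instance recovers one of three ground truth matches and the second recovers both, so `result` holds a recall of 1/3 and 1 respectively. The sketch does not guard against an instance whose ground truth matrix sums to zero, which would divide by zero.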