Skip to content
Snippets Groups Projects
eval.py 1.25 KiB
import torch
from torch import Tensor
from torch.nn import Module
from torch.nn.functional import nll_loss, cross_entropy
from SuperTagger.Linker.atom_map import atom_map
import re

from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes
from SuperTagger.utils import pad_sequence


class SinkhornLoss(Module):
    r"""Sum of negative log-likelihood losses over a sequence of link predictions.

    Each element of ``predictions`` is paired with the element of ``truths`` at the
    same position. For every pair, the first two dimensions of the prediction are
    merged and the target is fully flattened before calling ``nll_loss``. That call
    uses ``reduction='mean'`` and ``ignore_index=-1``, so targets equal to -1 do not
    contribute. The per-pair losses are added together.

    NOTE(review): ``nll_loss`` expects log-probabilities of shape (N, C) and targets
    of shape (N,). This implies each prediction is 3-D, presumably
    (batch, n, n) log-probabilities, and each truth is presumably (batch, n)
    index targets. TODO confirm against callers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, predictions, truths):
        # Start from the integer 0, as builtin ``sum`` would, so empty inputs yield 0.
        total = 0
        for link_scores, gold_perm in zip(predictions, truths):
            item_loss = nll_loss(
                link_scores.flatten(0, 1),
                gold_perm.flatten(),
                reduction='mean',
                ignore_index=-1,
            )
            total = total + item_loss
        return total


def mesure_accuracy(linking_plus_to_minus: Tensor, axiom_links_pred: Tensor) -> float:
    r"""Return the fraction of correctly predicted links, ignoring padded positions.

    A position counts as padding when its gold value in ``linking_plus_to_minus``
    is -1. Padding positions are excluded from both the numerator and the
    denominator.

    Args:
        linking_plus_to_minus: gold link indices, with -1 marking padding. Expected
            to have the same shape as ``axiom_links_pred`` (any rank).
            NOTE(review): earlier code assumed a 3-D tensor; the exact axis
            meaning is not visible here. TODO confirm against callers.
        axiom_links_pred: predicted link indices, compared element-wise with the
            gold tensor.

    Returns:
        The number of non-padding positions where the prediction equals the gold
        value, divided by the number of non-padding positions. Returns 0.0 when
        every position is padding (previously this raised ZeroDivisionError).
    """
    # Work on boolean masks directly instead of allocating a ones tensor.
    # The old approach allocated on CPU and broke for CUDA inputs.
    valid = linking_plus_to_minus != -1
    num_valid = int(valid.sum().item())
    if num_valid == 0:
        return 0.0

    num_correct = int(((axiom_links_pred == linking_plus_to_minus) & valid).sum().item())

    # Divide by the number of real (non-padding) links.
    return num_correct / num_valid