-
Caroline DE POURTALES authored
eval.py 1.25 KiB
import torch
from torch import Tensor
from torch.nn import Module
from torch.nn.functional import nll_loss, cross_entropy
from SuperTagger.Linker.atom_map import atom_map
import re
from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes
from SuperTagger.utils import pad_sequence
class SinkhornLoss(Module):
    """Sum of per-item negative log-likelihood losses over paired predictions and truths.

    Each prediction/truth pair is scored independently with ``nll_loss``
    (mean reduction, target value ``-1`` ignored) and the per-pair losses
    are added together.
    """

    def __init__(self):
        super().__init__()

    def forward(self, predictions, truths):
        """Return the summed NLL loss over ``zip(predictions, truths)``.

        predictions : iterable of tensors; the first two dimensions of each
            are merged before scoring. ``nll_loss`` interprets its input as
            log-probabilities -- presumably each tensor is
            (batch, n_atoms, n_atoms); TODO confirm against the caller.
        truths : iterable of index tensors, flattened to 1-D; entries equal
            to -1 do not contribute to the loss.

        Returns the integer 0 when the iterables are empty, otherwise a
        scalar tensor.
        """
        total_loss = 0
        for link_scores, permutation in zip(predictions, truths):
            flat_scores = link_scores.flatten(0, 1)
            flat_targets = permutation.flatten()
            pair_loss = nll_loss(flat_scores, flat_targets,
                                 reduction='mean', ignore_index=-1)
            total_loss = total_loss + pair_loss
        return total_loss
def mesure_accuracy(linking_plus_to_minus, axiom_links_pred):
    r"""
    Fraction of non-padded axiom links that were predicted correctly.

    linking_plus_to_minus : ground-truth link indices; entries equal to -1
        are padding and are excluded from both numerator and denominator.
    axiom_links_pred : predicted link indices, expected to have the same
        shape as ``linking_plus_to_minus``. Any number of dimensions is
        accepted (the callers presumably pass
        (batch_size, atom_types, max_atoms_type_polarity) -- TODO confirm).

    Returns a Python float: correct links / real links, or 0.0 when every
    entry is padding (there is nothing to score).
    """
    # Masks are derived from the inputs themselves, so they live on the same
    # device as the inputs (no CPU scratch tensor is allocated).
    valid_links = linking_plus_to_minus != -1
    num_valid_links = valid_links.sum().item()
    if num_valid_links == 0:
        return 0.0

    correct_links = (axiom_links_pred == linking_plus_to_minus) & valid_links
    # divide by the number of real (non-padded) links
    return correct_links.sum().item() / num_valid_links