From a702fd51349e73d7d21d7c1bba9e90f3e7948a77 Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <cdepourt@montana.irit.fr> Date: Thu, 5 May 2022 15:55:28 +0200 Subject: [PATCH] starting train --- SuperTagger/Linker/utils.py | 33 ++++++++++++++++----------------- SuperTagger/eval.py | 6 +++++- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py index d13f5dc..aa2ad6e 100644 --- a/SuperTagger/Linker/utils.py +++ b/SuperTagger/Linker/utils.py @@ -4,15 +4,16 @@ from SuperTagger.Linker.AtomTokenizer import AtomTokenizer from SuperTagger.Linker.atom_map import atom_map -def get_atoms_from_category(category, category_to_atoms): - if category in atom_map.keys(): +def category_to_atoms(category, category_to_atoms): + res = [i for i in atom_map.keys() if category in i] + if len(res) > 0: return [category] else: category_cut = re.search(r'\w*\(\d+,(.+),(.+)\)', category) left_side, right_side = category_cut.group(1), category_cut.group(2) - category_to_atoms += get_atoms_from_category(left_side, []) - category_to_atoms += get_atoms_from_category(right_side, []) + category_to_atoms += category_to_atoms(left_side, []) + category_to_atoms += category_to_atoms(right_side, []) return category_to_atoms @@ -22,12 +23,12 @@ def get_atoms_batch(category_batch): for sentence in category_batch: category_to_atoms = [] for category in sentence: - category_to_atoms = get_atoms_from_category(category, category_to_atoms) + category_to_atoms = category_to_atoms(category, category_to_atoms) batch.append(category_to_atoms) return batch -def cut_category_in_symbols(category): +def category_to_atoms_polarity(category): ''' Parameters : category : str of kind AtomCat | CategoryCat @@ -49,13 +50,13 @@ def cut_category_in_symbols(category): if left_side in atom_map.keys(): category_to_polarity.append(False) else: - category_to_polarity += cut_category_in_symbols(left_side) + category_to_polarity += category_to_atoms_polarity(left_side) # for the right side if right_side in atom_map.keys(): category_to_polarity.append(True) else: - category_to_polarity += cut_category_in_symbols(right_side) + category_to_polarity += category_to_atoms_polarity(right_side) # dl = \ elif category.startswith("dl"): @@ -66,18 +67,18 @@ def cut_category_in_symbols(category): if left_side in atom_map.keys(): category_to_polarity.append(True) else: - category_to_polarity += cut_category_in_symbols(left_side) + category_to_polarity += category_to_atoms_polarity(left_side) # for the right side if right_side in atom_map.keys(): category_to_polarity.append(False) else: - category_to_polarity += cut_category_in_symbols(right_side) + category_to_polarity += category_to_atoms_polarity(right_side) return category_to_polarity -def find_pos_neg_idexes(batch_symbols): +def find_pos_neg_idexes(atoms_batch): ''' Parameters : batch_symbols : (batch_size, sequence_length) the batch of symbols @@ -86,11 +87,9 @@ def find_pos_neg_idexes(batch_symbols): (batch_size, max_symbols_in_sentence) boolean tensor indiating pos and ne indexes ''' list_batch = [] - for sentence in batch_symbols: - list_symbols = [] + for sentence in atoms_batch: + list_atoms = [] for category in sentence: - list_symbols.append(cut_category_in_symbols(category)) - list_batch.append(list_symbols) + list_atoms.append(category_to_atoms_polarity(category)) + list_batch.append(list_atoms) return list_batch - - diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py index 426f5e6..07441e8 100644 --- a/SuperTagger/eval.py +++ b/SuperTagger/eval.py @@ -3,6 +3,8 @@ from torch import Tensor from torch.nn import Module from torch.nn.functional import nll_loss, cross_entropy +from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes + class SinkhornLoss(Module): def __init__(self): @@ -19,8 +21,10 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred): axiom_links_pred : (batch_size, max_atoms_type_polarity) """ # Convert batch_axiom_links into list of atoms (batch_size, max_atoms_in_sentence) + atoms_batch = get_atoms_batch(batch_axiom_links) # then convert into atom_vocab_size lists of (batch_size, max atom in one cat) with prefix parcours of graphe + atoms_polarity = find_pos_neg_idexes(atoms_batch) axiom_links_true = "" @@ -30,4 +34,4 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred): correct_links[axiom_links_pred != axiom_links_true] = 0 num_correct_links = correct_links.sum().item() - return num_correct_links \ No newline at end of file + return num_correct_links/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1]) -- GitLab