From a702fd51349e73d7d21d7c1bba9e90f3e7948a77 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Thu, 5 May 2022 15:55:28 +0200
Subject: [PATCH] starting train

---
 SuperTagger/Linker/utils.py | 33 ++++++++++++++++-----------------
 SuperTagger/eval.py         |  6 +++++-
 2 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
index d13f5dc..aa2ad6e 100644
--- a/SuperTagger/Linker/utils.py
+++ b/SuperTagger/Linker/utils.py
@@ -4,15 +4,16 @@ from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
 from SuperTagger.Linker.atom_map import atom_map
 
 
-def get_atoms_from_category(category, category_to_atoms):
-    if category in atom_map.keys():
+def category_to_atoms(category, category_to_atoms):
+    res = [i for i in atom_map.keys() if category in i]
+    if len(res) > 0:
         return [category]
     else:
         category_cut = re.search(r'\w*\(\d+,(.+),(.+)\)', category)
         left_side, right_side = category_cut.group(1), category_cut.group(2)
 
-        category_to_atoms += get_atoms_from_category(left_side, [])
-        category_to_atoms += get_atoms_from_category(right_side, [])
+        category_to_atoms += category_to_atoms(left_side, [])
+        category_to_atoms += category_to_atoms(right_side, [])
 
         return category_to_atoms
 
@@ -22,12 +23,12 @@ def get_atoms_batch(category_batch):
     for sentence in category_batch:
         category_to_atoms = []
         for category in sentence:
-            category_to_atoms = get_atoms_from_category(category, category_to_atoms)
+            category_to_atoms = category_to_atoms(category, category_to_atoms)
         batch.append(category_to_atoms)
     return batch
 
 
-def cut_category_in_symbols(category):
+def category_to_atoms_polarity(category):
     '''
     Parameters :
     category : str of kind AtomCat | CategoryCat
@@ -49,13 +50,13 @@ def cut_category_in_symbols(category):
             if left_side in atom_map.keys():
                 category_to_polarity.append(False)
             else:
-                category_to_polarity += cut_category_in_symbols(left_side)
+                category_to_polarity += category_to_atoms_polarity(left_side)
 
             # for the right side
             if right_side in atom_map.keys():
                 category_to_polarity.append(True)
             else:
-                category_to_polarity += cut_category_in_symbols(right_side)
+                category_to_polarity += category_to_atoms_polarity(right_side)
 
         # dl = \
         elif category.startswith("dl"):
@@ -66,18 +67,18 @@ def cut_category_in_symbols(category):
             if left_side in atom_map.keys():
                 category_to_polarity.append(True)
             else:
-                category_to_polarity += cut_category_in_symbols(left_side)
+                category_to_polarity += category_to_atoms_polarity(left_side)
 
             # for the right side
             if right_side in atom_map.keys():
                 category_to_polarity.append(False)
             else:
-                category_to_polarity += cut_category_in_symbols(right_side)
+                category_to_polarity += category_to_atoms_polarity(right_side)
 
     return category_to_polarity
 
 
-def find_pos_neg_idexes(batch_symbols):
+def find_pos_neg_idexes(atoms_batch):
     '''
     Parameters :
     batch_symbols : (batch_size, sequence_length) the batch of symbols
@@ -86,11 +87,9 @@ def find_pos_neg_idexes(batch_symbols):
     (batch_size, max_symbols_in_sentence) boolean tensor indiating pos and ne indexes
     '''
     list_batch = []
-    for sentence in batch_symbols:
-        list_symbols = []
+    for sentence in atoms_batch:
+        list_atoms = []
         for category in sentence:
-            list_symbols.append(cut_category_in_symbols(category))
-        list_batch.append(list_symbols)
+            list_atoms.append(category_to_atoms_polarity(category))
+        list_batch.append(list_atoms)
     return list_batch
-
-
diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py
index 426f5e6..07441e8 100644
--- a/SuperTagger/eval.py
+++ b/SuperTagger/eval.py
@@ -3,6 +3,8 @@ from torch import Tensor
 from torch.nn import Module
 from torch.nn.functional import nll_loss, cross_entropy
 
+from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes
+
 
 class SinkhornLoss(Module):
     def __init__(self):
@@ -19,8 +21,10 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred):
     axiom_links_pred : (batch_size, max_atoms_type_polarity)
     """
     # Convert batch_axiom_links into list of atoms (batch_size, max_atoms_in_sentence)
+    atoms_batch = get_atoms_batch(batch_axiom_links)
 
     # then convert into atom_vocab_size lists of (batch_size, max atom in one cat) with prefix parcours of graphe
+    atoms_polarity = find_pos_neg_idexes(atoms_batch)
 
     axiom_links_true = ""
 
@@ -30,4 +34,4 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred):
     correct_links[axiom_links_pred != axiom_links_true] = 0
     num_correct_links = correct_links.sum().item()
 
-    return num_correct_links
\ No newline at end of file
+    return num_correct_links/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1])
-- 
GitLab