From 8b0f5bb5881ad04b309246b848a5bdcffb3fafab Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Wed, 11 May 2022 16:55:21 +0200
Subject: [PATCH] adding comments

---
 SuperTagger/Linker/Linker.py | 12 ++++-----
 SuperTagger/Linker/utils.py  | 47 ++++++++++++++++++++++++++++++------
 SuperTagger/eval.py          |  4 +--
 3 files changed, 47 insertions(+), 16 deletions(-)

diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
index 281a7ab..ef03e0e 100644
--- a/SuperTagger/Linker/Linker.py
+++ b/SuperTagger/Linker/Linker.py
@@ -62,13 +62,13 @@ class Linker(Module):
         )
 
     def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding):
-        '''
+        r'''
         Parameters :
-        category_batch : batch of size (batch_size, sequence_length) = output of decoder
-        sents_embedding
-        sents_mask
-        Retturns :
-        link_weights : batch-size, atom_vocab_size, ...)
+        atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
+        atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
+        sents_embedding : output of BERT for context
+        Returns :
+        link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat)
         '''
 
         # atoms embedding
diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
index c4b7a79..d951926 100644
--- a/SuperTagger/Linker/utils.py
+++ b/SuperTagger/Linker/utils.py
@@ -16,6 +16,14 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
 
 
 def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
+    r'''
+    Parameters :
+    max_atoms_in_one_type : configuration
+    atoms_polarity : (batch_size, max_atoms_in_sentence)
+    batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
+    Returns :
+    batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
+    '''
     atoms_batch = get_atoms_links_batch(batch_axiom_links)
     linking_plus_to_minus_all_types = []
     for atom_type in list(atom_map.keys())[:-1]:
@@ -37,6 +45,13 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
 
 
 def category_to_atoms_axiom_links(category, categories_to_atoms):
+    r'''
+    Parameters :
+    category
+    categories_to_atoms : recursive list
+    Returns :
+    List of atoms inside the category in prefix order
+    '''
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
@@ -52,6 +67,11 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
 
 
 def get_atoms_links_batch(category_batch):
+    r"""
+    category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
+    Returns :
+     (batch_size, max_atoms_in_sentence) flattened categories in prefix order
+    """
     batch = []
     for sentence in category_batch:
         categories_to_atoms = []
@@ -67,6 +87,13 @@ def get_atoms_links_batch(category_batch):
 
 
 def category_to_atoms(category, categories_to_atoms):
+    r'''
+    Parameters :
+    category
+    categories_to_atoms : recursive list
+    Returns :
+    List of atoms inside the category in prefix order
+    '''
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
@@ -84,6 +111,11 @@ def category_to_atoms(category, categories_to_atoms):
 
 
 def get_atoms_batch(category_batch):
+    r"""
+    category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
+    Returns :
+     (batch_size, max_atoms_in_sentence) flattened categories in prefix order
+    """
     batch = []
     for sentence in category_batch:
         categories_to_atoms = []
@@ -98,9 +130,9 @@ def get_atoms_batch(category_batch):
 #########################################################################################
 
 def category_to_atoms_polarity(category, polarity):
-    '''
+    r'''
     Parameters :
-    category : str of kind AtomCat | CategoryCat
+    category : str of kind AtomCat | CategoryCat(dr or dl)
     Returns :
     Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
     '''
@@ -183,13 +215,12 @@ def category_to_atoms_polarity(category, polarity):
 
 
 def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
-    '''
-    Parameters :
-    batch_symbols : (batch_size, sequence_length) the batch of symbols
-
+    r"""
+    max_atoms_in_sentence : configuration
+    atoms_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
     Returns :
-    (batch_size, max_symbols_in_sentence) boolean tensor indiating pos and ne indexes
-    '''
+     (batch_size, max_atoms_in_sentence) flattened categories'polarities in prefix order
+    """
     list_batch = []
     for sentence in atoms_batch:
         list_atoms = []
diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py
index b287d4b..2731514 100644
--- a/SuperTagger/eval.py
+++ b/SuperTagger/eval.py
@@ -20,8 +20,8 @@ class SinkhornLoss(Module):
 
 def mesure_accuracy(batch_true_links, axiom_links_pred):
     r"""
-    batch_axiom_links : (batch_size, ...)
-    axiom_links_pred : (batch_size, max_atoms_type_polarity)
+    batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
+    axiom_links_pred : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
     correct_links = torch.ones(axiom_links_pred.size())
     correct_links[axiom_links_pred != batch_true_links] = 0
-- 
GitLab