diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py index 281a7ab2dc74827c3ebee38780c912d5ed64a38a..ef03e0e44d146e8a6ae70bebd9e41123c6b6ba1d 100644 --- a/SuperTagger/Linker/Linker.py +++ b/SuperTagger/Linker/Linker.py @@ -62,13 +62,13 @@ class Linker(Module): ) def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding): - ''' + r''' Parameters : - category_batch : batch of size (batch_size, sequence_length) = output of decoder - sents_embedding - sents_mask - Retturns : - link_weights : batch-size, atom_vocab_size, ...) + atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories + atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities + sents_embedding : output of BERT for context + Returns : + link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat) ''' # atoms embedding diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py index c4b7a799135ead5e08fe5d21a03189bd7f7db071..d95192689514640f048838aad2030fcf959272bc 100644 --- a/SuperTagger/Linker/utils.py +++ b/SuperTagger/Linker/utils.py @@ -16,6 +16,14 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)' def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links): + r''' + Parameters : + max_atoms_in_one_type : configuration + atoms_polarity : (batch_size, max_atoms_in_sentence) + batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms + Returns : + batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms + ''' atoms_batch = get_atoms_links_batch(batch_axiom_links) linking_plus_to_minus_all_types = [] for atom_type in list(atom_map.keys())[:-1]: @@ -37,6 +45,13 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links): def category_to_atoms_axiom_links(category, categories_to_atoms): + r''' + Parameters : + category + categories_to_atoms : recursive list + Returns : + List of atoms inside the category in prefix order + ''' res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()] if category.startswith("GOAL:"): word, cat = category.split(':') @@ -52,6 +67,11 @@ def category_to_atoms_axiom_links(category, categories_to_atoms): def get_atoms_links_batch(category_batch): + r""" + category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order + Returns : + (batch_size, max_atoms_in_sentence) flattened categories in prefix order + """ batch = [] for sentence in category_batch: categories_to_atoms = [] @@ -67,6 +87,13 @@ def get_atoms_links_batch(category_batch): def category_to_atoms(category, categories_to_atoms): + r''' + Parameters : + category + categories_to_atoms : recursive list + Returns : + List of atoms inside the category in prefix order + ''' res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()] if category.startswith("GOAL:"): word, cat = category.split(':') @@ -84,6 +111,11 @@ def category_to_atoms(category, categories_to_atoms): def get_atoms_batch(category_batch): + r""" + category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order + Returns : + (batch_size, max_atoms_in_sentence) flattened categories in prefix order + """ batch = [] for sentence in category_batch: categories_to_atoms = [] @@ -98,9 +130,9 @@ def get_atoms_batch(category_batch): ######################################################################################### def category_to_atoms_polarity(category, polarity): - ''' + r''' Parameters : - category : str of kind AtomCat | CategoryCat + category : str of kind AtomCat | CategoryCat(dr or dl) Returns : Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes ''' @@ -183,13 +215,12 @@ def category_to_atoms_polarity(category, polarity): def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch): - ''' - Parameters : - batch_symbols : (batch_size, sequence_length) the batch of symbols - + r""" + max_atoms_in_sentence : configuration + atoms_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order Returns : - (batch_size, max_symbols_in_sentence) boolean tensor indiating pos and ne indexes - ''' + (batch_size, max_atoms_in_sentence) flattened categories'polarities in prefix order + """ list_batch = [] for sentence in atoms_batch: list_atoms = [] diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py index b287d4b18a382826db7d9695ae566d0fd6df0224..2731514885b6da2bd84661d8c6c2149ad1645b9d 100644 --- a/SuperTagger/eval.py +++ b/SuperTagger/eval.py @@ -20,8 +20,8 @@ class SinkhornLoss(Module): def mesure_accuracy(batch_true_links, axiom_links_pred): r""" - batch_axiom_links : (batch_size, ...) - axiom_links_pred : (batch_size, max_atoms_type_polarity) + batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms + axiom_links_pred : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms """ correct_links = torch.ones(axiom_links_pred.size()) correct_links[axiom_links_pred != batch_true_links] = 0