Skip to content
Snippets Groups Projects
Commit 8b0f5bb5 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

adding comments

parent f996b207
No related branches found
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
...@@ -62,13 +62,13 @@ class Linker(Module): ...@@ -62,13 +62,13 @@ class Linker(Module):
) )
def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding): def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding):
''' r'''
Parameters : Parameters :
category_batch : batch of size (batch_size, sequence_length) = output of decoder atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
sents_embedding atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
sents_mask sents_embedding : output of BERT for context
Retturns : Returns :
link_weights : batch-size, atom_vocab_size, ...) link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat)
''' '''
# atoms embedding # atoms embedding
......
...@@ -16,6 +16,14 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)' ...@@ -16,6 +16,14 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links): def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
r'''
Parameters :
max_atoms_in_one_type : configuration
atoms_polarity : (batch_size, max_atoms_in_sentence)
batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
Returns :
batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
'''
atoms_batch = get_atoms_links_batch(batch_axiom_links) atoms_batch = get_atoms_links_batch(batch_axiom_links)
linking_plus_to_minus_all_types = [] linking_plus_to_minus_all_types = []
for atom_type in list(atom_map.keys())[:-1]: for atom_type in list(atom_map.keys())[:-1]:
...@@ -37,6 +45,13 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links): ...@@ -37,6 +45,13 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
def category_to_atoms_axiom_links(category, categories_to_atoms): def category_to_atoms_axiom_links(category, categories_to_atoms):
r'''
Parameters :
category
categories_to_atoms : recursive list
Returns :
List of atoms inside the category in prefix order
'''
res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()] res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
if category.startswith("GOAL:"): if category.startswith("GOAL:"):
word, cat = category.split(':') word, cat = category.split(':')
...@@ -52,6 +67,11 @@ def category_to_atoms_axiom_links(category, categories_to_atoms): ...@@ -52,6 +67,11 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
def get_atoms_links_batch(category_batch): def get_atoms_links_batch(category_batch):
r"""
category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
Returns :
(batch_size, max_atoms_in_sentence) flattened categories in prefix order
"""
batch = [] batch = []
for sentence in category_batch: for sentence in category_batch:
categories_to_atoms = [] categories_to_atoms = []
...@@ -67,6 +87,13 @@ def get_atoms_links_batch(category_batch): ...@@ -67,6 +87,13 @@ def get_atoms_links_batch(category_batch):
def category_to_atoms(category, categories_to_atoms): def category_to_atoms(category, categories_to_atoms):
r'''
Parameters :
category
categories_to_atoms : recursive list
Returns :
List of atoms inside the category in prefix order
'''
res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()] res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
if category.startswith("GOAL:"): if category.startswith("GOAL:"):
word, cat = category.split(':') word, cat = category.split(':')
...@@ -84,6 +111,11 @@ def category_to_atoms(category, categories_to_atoms): ...@@ -84,6 +111,11 @@ def category_to_atoms(category, categories_to_atoms):
def get_atoms_batch(category_batch): def get_atoms_batch(category_batch):
r"""
category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
Returns :
(batch_size, max_atoms_in_sentence) flattened categories in prefix order
"""
batch = [] batch = []
for sentence in category_batch: for sentence in category_batch:
categories_to_atoms = [] categories_to_atoms = []
...@@ -98,9 +130,9 @@ def get_atoms_batch(category_batch): ...@@ -98,9 +130,9 @@ def get_atoms_batch(category_batch):
######################################################################################### #########################################################################################
def category_to_atoms_polarity(category, polarity): def category_to_atoms_polarity(category, polarity):
''' r'''
Parameters : Parameters :
category : str of kind AtomCat | CategoryCat category : str of kind AtomCat | CategoryCat(dr or dl)
Returns : Returns :
Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
''' '''
...@@ -183,13 +215,12 @@ def category_to_atoms_polarity(category, polarity): ...@@ -183,13 +215,12 @@ def category_to_atoms_polarity(category, polarity):
def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch): def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
''' r"""
Parameters : max_atoms_in_sentence : configuration
batch_symbols : (batch_size, sequence_length) the batch of symbols atoms_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
Returns : Returns :
(batch_size, max_symbols_in_sentence) boolean tensor indiating pos and ne indexes (batch_size, max_atoms_in_sentence) flattened categories'polarities in prefix order
''' """
list_batch = [] list_batch = []
for sentence in atoms_batch: for sentence in atoms_batch:
list_atoms = [] list_atoms = []
......
...@@ -20,8 +20,8 @@ class SinkhornLoss(Module): ...@@ -20,8 +20,8 @@ class SinkhornLoss(Module):
def mesure_accuracy(batch_true_links, axiom_links_pred): def mesure_accuracy(batch_true_links, axiom_links_pred):
r""" r"""
batch_axiom_links : (batch_size, ...) batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
axiom_links_pred : (batch_size, max_atoms_type_polarity) axiom_links_pred : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
""" """
correct_links = torch.ones(axiom_links_pred.size()) correct_links = torch.ones(axiom_links_pred.size())
correct_links[axiom_links_pred != batch_true_links] = 0 correct_links[axiom_links_pred != batch_true_links] = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment