diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py index b8dbf8c4083b66786bc43e10824e288e2c65aeff..281a7ab2dc74827c3ebee38780c912d5ed64a38a 100644 --- a/SuperTagger/Linker/Linker.py +++ b/SuperTagger/Linker/Linker.py @@ -61,11 +61,6 @@ class Linker(Module): LayerNorm(self.dim_embedding_atoms, eps=1e-12) ) - def make_decoder_mask(self, atoms_batch): - decoder_attn_mask = torch.ones_like(atoms_batch, dtype=torch.float64) - decoder_attn_mask[atoms_batch.eq(self.padding_id)] = 0.0 - return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_batch.shape[1], 1) - def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding): ''' Parameters : diff --git a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc b/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc index 7f42ae68119786d9252d2c8eac5eb9c7a7083377..b13bf57006d2d3eb78625e40b86218f03e27798c 100644 Binary files a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc and b/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc differ diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py index 898a921d77b96113a50513ccb94502e368a7de60..3f8e892a7f693fc58dd6754b11a114091d364480 100644 --- a/SuperTagger/Linker/utils.py +++ b/SuperTagger/Linker/utils.py @@ -9,6 +9,7 @@ from SuperTagger.utils import pad_sequence regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)' + ######################################################################################### ################################ Liste des atoms avc _i######################################## ######################################################################################### @@ -26,17 +27,17 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links): range(len(atoms_batch))] linking_plus_to_minus = pad_sequence( - [torch.as_tensor([l_polarity_minus[s_idx].index(x) for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long) - for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type//2, padding_value=-1) + [torch.as_tensor([l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else -1 for i, x in + enumerate(l_polarity_plus[s_idx])], dtype=torch.long) + for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2, padding_value=-1) linking_plus_to_minus_all_types.append(linking_plus_to_minus) return torch.stack(linking_plus_to_minus_all_types) - def category_to_atoms_axiom_links(category, categories_to_atoms): - res = [bool(re.match(r''+atom_type+"_\d+", category)) for atom_type in atom_map.keys()] + res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()] if True in res: return [category] else: @@ -56,13 +57,14 @@ def get_atoms_links_batch(category_batch): batch.append(categories_to_atoms) return batch + ######################################################################################### ################################ Liste des atoms ######################################## ######################################################################################### def category_to_atoms(category, categories_to_atoms): - res = [bool(re.match(r''+atom_type+"_\d+", category)) for atom_type in atom_map.keys()] + res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()] if True in res: category = re.match(r'([a-zA-Z|_]+)_\d+', category).group(1) return [category] @@ -96,9 +98,9 @@ def category_to_atoms_polarity(category, polarity): Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes ''' category_to_polarity = [] - res = [bool(re.match(r''+atom_type+"_\d+", category)) for atom_type in atom_map.keys()] + res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()] if True in res or category.startswith("dia") or category.startswith("box"): - category_to_polarity.append(not polarity) + category_to_polarity.append(False) else: # dr = / if category.startswith("dr"): @@ -106,15 +108,33 @@ def category_to_atoms_polarity(category, polarity): category_cut = [cat for cat in category_cut if cat is not None] left_side, right_side = category_cut[0], category_cut[1] - # for the left side - category_to_polarity += category_to_atoms_polarity(left_side, not polarity) - - # for the right side - res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()] - if True in res or right_side.startswith("dia") or right_side.startswith("box"): - category_to_polarity.append(polarity) - else : - category_to_polarity += category_to_atoms_polarity(right_side, not polarity) + if polarity == True: + # for the left side : normal + res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()] + if True in res or left_side.startswith("dia") or left_side.startswith("box"): + category_to_polarity.append(False) + else: + category_to_polarity += category_to_atoms_polarity(left_side, True) + # for the right side : change polarity for next right formula + res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()] + if True in res or right_side.startswith("dia") or right_side.startswith("box"): + category_to_polarity.append(True) + else: + category_to_polarity += category_to_atoms_polarity(right_side, False) + + else: + # for the left side + res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()] + if True in res or left_side.startswith("dia") or left_side.startswith("box"): + category_to_polarity.append(True) + else: + category_to_polarity += category_to_atoms_polarity(left_side, False) + # for the right side : change polarity for next right formula + res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()] + if True in res or right_side.startswith("dia") or right_side.startswith("box"): + category_to_polarity.append(False) + else: + category_to_polarity += category_to_atoms_polarity(right_side, True) # dl = \ elif category.startswith("dl"): @@ -122,15 +142,33 @@ def category_to_atoms_polarity(category, polarity): category_cut = [cat for cat in category_cut if cat is not None] left_side, right_side = category_cut[0], category_cut[1] - # for the left side - res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()] - if True in res or left_side.startswith("dia") or left_side.startswith("box"): - category_to_polarity.append(polarity) - else : - category_to_polarity += category_to_atoms_polarity(left_side, not polarity) - - # for the right side - category_to_polarity += category_to_atoms_polarity(right_side, not polarity) + if polarity == True: + # for the left side : change polarity + res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()] + if True in res or left_side.startswith("dia") or left_side.startswith("box"): + category_to_polarity.append(True) + else: + category_to_polarity += category_to_atoms_polarity(left_side, False) + # for the right side : normal + res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()] + if True in res or right_side.startswith("dia") or right_side.startswith("box"): + category_to_polarity.append(False) + else: + category_to_polarity += category_to_atoms_polarity(right_side, True) + + else: + # for the left side + res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()] + if True in res or left_side.startswith("dia") or left_side.startswith("box"): + category_to_polarity.append(False) + else: + category_to_polarity += category_to_atoms_polarity(left_side, True) + # for the right side + res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()] + if True in res or right_side.startswith("dia") or right_side.startswith("box"): + category_to_polarity.append(True) + else: + category_to_polarity += category_to_atoms_polarity(right_side, False) return category_to_polarity @@ -147,13 +185,8 @@ def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch): for sentence in atoms_batch: list_atoms = [] for category in sentence: - polarity = True - for at in category_to_atoms_polarity(category, polarity): + for at in category_to_atoms_polarity(category, True): list_atoms.append(at) list_batch.append(torch.as_tensor(list_atoms)) return pad_sequence([list_batch[i] for i in range(len(list_batch))], max_len=max_atoms_in_sentence, padding_value=0) - -atoms_pol = find_pos_neg_idexes(10, [['dr(1,np_1,s_1)', 'dl(1,np_1,s_1)']]) -print(atoms_pol) -print(get_axiom_links(10, atoms_pol, [['dr(1,np_1,s_1)', 'dl(1,np_1,s_1)']])) diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py index 9bfdc85fd5ccf485364c808e9d6410da88f2f391..b287d4b18a382826db7d9695ae566d0fd6df0224 100644 --- a/SuperTagger/eval.py +++ b/SuperTagger/eval.py @@ -18,16 +18,16 @@ class SinkhornLoss(Module): for link, perm in zip(predictions, truths)) -def mesure_accuracy(linking_plus_to_minus, axiom_links_pred): +def mesure_accuracy(batch_true_links, axiom_links_pred): r""" batch_axiom_links : (batch_size, ...) axiom_links_pred : (batch_size, max_atoms_type_polarity) """ correct_links = torch.ones(axiom_links_pred.size()) - correct_links[axiom_links_pred != linking_plus_to_minus] = 0 - correct_links[linking_plus_to_minus == -1] = 1 + correct_links[axiom_links_pred != batch_true_links] = 0 + correct_links[batch_true_links == -1] = 1 num_correct_links = correct_links.sum().item() - num_masked_atoms = len(linking_plus_to_minus[linking_plus_to_minus == -1]) + num_masked_atoms = len(batch_true_links[batch_true_links == -1]) # diviser par nombre de links return (num_correct_links - num_masked_atoms)/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2] - num_masked_atoms) diff --git a/train.py b/train.py index fced23bdaadd03c2ede10dac3919b65e78584657..05d223f6e3b9710d18cf2f6984e36f3e257209c7 100644 --- a/train.py +++ b/train.py @@ -53,9 +53,11 @@ print("atoms_polarity_batch", atoms_polarity_batch.shape) truth_links_batch = get_axiom_links(max_atoms_in_one_type, atoms_polarity_batch, df_axiom_links["sub_tree"]) print("truth_links_batch", truth_links_batch.permute(1, 0, 2).shape) +print(" truth_links_batch example on first sentence class cl_r", truth_links_batch[0][0]) sentences_batch = df_axiom_links["Sentences"] +# Construction tensor dataset dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch.permute(1, 0, 2)) # Calculate the number of samples to include in each set.