Skip to content
Snippets Groups Projects
Commit 9f49c6d7 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

it runs, some corrections needed next

parent a874cd6d
Branches
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
......@@ -61,11 +61,6 @@ class Linker(Module):
LayerNorm(self.dim_embedding_atoms, eps=1e-12)
)
# Build a float64 mask from atoms_batch: 1.0 everywhere, 0.0 at every position
# whose value equals self.padding_id. The mask is then given a new dimension 1
# and repeated atoms_batch.shape[1] times along it.
# NOTE(review): shape[1] plus the 3-argument repeat imply atoms_batch is 2-D,
# presumably (batch, seq_len), giving a (batch, seq_len, seq_len) result — confirm.
def make_decoder_mask(self, atoms_batch):
# start from all ones with the same shape as atoms_batch
decoder_attn_mask = torch.ones_like(atoms_batch, dtype=torch.float64)
# zero out the positions that hold the padding id
decoder_attn_mask[atoms_batch.eq(self.padding_id)] = 0.0
# insert a dimension at index 1 and tile the mask along it
return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_batch.shape[1], 1)
def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding):
'''
Parameters :
......
No preview for this file type
......@@ -9,6 +9,7 @@ from SuperTagger.utils import pad_sequence
regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
#########################################################################################
################################ List of atoms with _i index ########################################
#########################################################################################
......@@ -26,17 +27,17 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
range(len(atoms_batch))]
linking_plus_to_minus = pad_sequence(
[torch.as_tensor([l_polarity_minus[s_idx].index(x) for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type//2, padding_value=-1)
[torch.as_tensor([l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else -1 for i, x in
enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2, padding_value=-1)
linking_plus_to_minus_all_types.append(linking_plus_to_minus)
return torch.stack(linking_plus_to_minus_all_types)
def category_to_atoms_axiom_links(category, categories_to_atoms):
res = [bool(re.match(r''+atom_type+"_\d+", category)) for atom_type in atom_map.keys()]
res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
if True in res:
return [category]
else:
......@@ -56,13 +57,14 @@ def get_atoms_links_batch(category_batch):
batch.append(categories_to_atoms)
return batch
#########################################################################################
################################ List of atoms ########################################
#########################################################################################
def category_to_atoms(category, categories_to_atoms):
res = [bool(re.match(r''+atom_type+"_\d+", category)) for atom_type in atom_map.keys()]
res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
if True in res:
category = re.match(r'([a-zA-Z|_]+)_\d+', category).group(1)
return [category]
......@@ -96,9 +98,9 @@ def category_to_atoms_polarity(category, polarity):
Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
'''
category_to_polarity = []
res = [bool(re.match(r''+atom_type+"_\d+", category)) for atom_type in atom_map.keys()]
res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
if True in res or category.startswith("dia") or category.startswith("box"):
category_to_polarity.append(not polarity)
category_to_polarity.append(False)
else:
# dr = /
if category.startswith("dr"):
......@@ -106,15 +108,33 @@ def category_to_atoms_polarity(category, polarity):
category_cut = [cat for cat in category_cut if cat is not None]
left_side, right_side = category_cut[0], category_cut[1]
# for the left side
category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
# for the right side
res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
if True in res or right_side.startswith("dia") or right_side.startswith("box"):
category_to_polarity.append(polarity)
else :
category_to_polarity += category_to_atoms_polarity(right_side, not polarity)
if polarity == True:
# for the left side : normal
res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
if True in res or left_side.startswith("dia") or left_side.startswith("box"):
category_to_polarity.append(False)
else:
category_to_polarity += category_to_atoms_polarity(left_side, True)
# for the right side : change polarity for next right formula
res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
if True in res or right_side.startswith("dia") or right_side.startswith("box"):
category_to_polarity.append(True)
else:
category_to_polarity += category_to_atoms_polarity(right_side, False)
else:
# for the left side
res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
if True in res or left_side.startswith("dia") or left_side.startswith("box"):
category_to_polarity.append(True)
else:
category_to_polarity += category_to_atoms_polarity(left_side, False)
# for the right side : change polarity for next right formula
res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
if True in res or right_side.startswith("dia") or right_side.startswith("box"):
category_to_polarity.append(False)
else:
category_to_polarity += category_to_atoms_polarity(right_side, True)
# dl = \
elif category.startswith("dl"):
......@@ -122,15 +142,33 @@ def category_to_atoms_polarity(category, polarity):
category_cut = [cat for cat in category_cut if cat is not None]
left_side, right_side = category_cut[0], category_cut[1]
# for the left side
res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
if True in res or left_side.startswith("dia") or left_side.startswith("box"):
category_to_polarity.append(polarity)
else :
category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
# for the right side
category_to_polarity += category_to_atoms_polarity(right_side, not polarity)
if polarity == True:
# for the left side : change polarity
res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
if True in res or left_side.startswith("dia") or left_side.startswith("box"):
category_to_polarity.append(True)
else:
category_to_polarity += category_to_atoms_polarity(left_side, False)
# for the right side : normal
res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
if True in res or right_side.startswith("dia") or right_side.startswith("box"):
category_to_polarity.append(False)
else:
category_to_polarity += category_to_atoms_polarity(right_side, True)
else:
# for the left side
res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
if True in res or left_side.startswith("dia") or left_side.startswith("box"):
category_to_polarity.append(False)
else:
category_to_polarity += category_to_atoms_polarity(left_side, True)
# for the right side
res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
if True in res or right_side.startswith("dia") or right_side.startswith("box"):
category_to_polarity.append(True)
else:
category_to_polarity += category_to_atoms_polarity(right_side, False)
return category_to_polarity
......@@ -147,13 +185,8 @@ def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
for sentence in atoms_batch:
list_atoms = []
for category in sentence:
polarity = True
for at in category_to_atoms_polarity(category, polarity):
for at in category_to_atoms_polarity(category, True):
list_atoms.append(at)
list_batch.append(torch.as_tensor(list_atoms))
return pad_sequence([list_batch[i] for i in range(len(list_batch))],
max_len=max_atoms_in_sentence, padding_value=0)
# Ad-hoc check on a single sentence made of two categories: compute its
# polarity tensor with max_atoms_in_sentence=10, print it, then print the
# axiom links that get_axiom_links returns for the same sentence.
# NOTE(review): these statements appear to sit at module level, so they would
# run (and print) every time the module is imported — confirm this is intended.
atoms_pol = find_pos_neg_idexes(10, [['dr(1,np_1,s_1)', 'dl(1,np_1,s_1)']])
print(atoms_pol)
print(get_axiom_links(10, atoms_pol, [['dr(1,np_1,s_1)', 'dl(1,np_1,s_1)']]))
......@@ -18,16 +18,16 @@ class SinkhornLoss(Module):
for link, perm in zip(predictions, truths))
def mesure_accuracy(batch_true_links, axiom_links_pred):
    r"""
    Return the fraction of real (non-padding) axiom links predicted correctly.

    Positions where ``batch_true_links`` equals -1 are treated as padding:
    they are counted neither as correct links nor in the total.

    Parameters :
        batch_true_links : integer tensor of ground-truth link indices, with -1
            at padded positions; must have the same shape as axiom_links_pred.
        axiom_links_pred : integer tensor of predicted link indices. Any rank is
            accepted (the earlier version multiplied exactly three dimensions).

    Returns :
        float in [0, 1]; 0.0 when there is no non-padding link at all, instead
        of dividing by zero.
    """
    # Only positions holding a real link take part in the score.
    valid_links = batch_true_links != -1
    num_links = int(valid_links.sum().item())
    if num_links == 0:
        return 0.0

    # A link is correct when the prediction matches the truth at a valid position.
    correct_links = (axiom_links_pred == batch_true_links) & valid_links

    # Divide by the number of real links.
    return correct_links.sum().item() / num_links
......@@ -53,9 +53,11 @@ print("atoms_polarity_batch", atoms_polarity_batch.shape)
truth_links_batch = get_axiom_links(max_atoms_in_one_type, atoms_polarity_batch, df_axiom_links["sub_tree"])
print("truth_links_batch", truth_links_batch.permute(1, 0, 2).shape)
print(" truth_links_batch example on first sentence class cl_r", truth_links_batch[0][0])
sentences_batch = df_axiom_links["Sentences"]
# Construction tensor dataset
dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch.permute(1, 0, 2))
# Calculate the number of samples to include in each set.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment