Skip to content
Snippets Groups Projects
Commit a874cd6d authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

it runs, some corrections needed next

parent e6ae31ff
Branches
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
...@@ -90,13 +90,13 @@ class Linker(Module): ...@@ -90,13 +90,13 @@ class Linker(Module):
if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)]) atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
for s_idx in range(len(atoms_polarity_batch))], padding_value=19, max_len=self.max_atoms_in_one_type//2) for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx]) neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
not atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)]) not atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
for s_idx in range(len(atoms_polarity_batch))], padding_value=19, max_len=self.max_atoms_in_one_type//2) for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
pos_encoding = self.pos_transformation(pos_encoding) pos_encoding = self.pos_transformation(pos_encoding)
neg_encoding = self.neg_transformation(neg_encoding) neg_encoding = self.neg_transformation(neg_encoding)
...@@ -104,7 +104,7 @@ class Linker(Module): ...@@ -104,7 +104,7 @@ class Linker(Module):
weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1)) weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
link_weights.append(sinkhorn(weights, iters=3)) link_weights.append(sinkhorn(weights, iters=3))
return torch.cat([link_weights[i].unsqueeze(0) for i in range(len(link_weights))]) return torch.stack(link_weights)
def eval_batch(self, batch, cross_entropy_loss): def eval_batch(self, batch, cross_entropy_loss):
batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu") batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
......
No preview for this file type
No preview for this file type
...@@ -27,11 +27,12 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links): ...@@ -27,11 +27,12 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
linking_plus_to_minus = pad_sequence( linking_plus_to_minus = pad_sequence(
[torch.as_tensor([l_polarity_minus[s_idx].index(x) for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long) [torch.as_tensor([l_polarity_minus[s_idx].index(x) for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type//2, padding_value=0) for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type//2, padding_value=-1)
linking_plus_to_minus_all_types.append(linking_plus_to_minus) linking_plus_to_minus_all_types.append(linking_plus_to_minus)
return torch.cat([linking_plus_to_minus_all_types[i].unsqueeze(0) for i in range(len(linking_plus_to_minus_all_types))]) return torch.stack(linking_plus_to_minus_all_types)
def category_to_atoms_axiom_links(category, categories_to_atoms): def category_to_atoms_axiom_links(category, categories_to_atoms):
...@@ -97,7 +98,7 @@ def category_to_atoms_polarity(category, polarity): ...@@ -97,7 +98,7 @@ def category_to_atoms_polarity(category, polarity):
category_to_polarity = [] category_to_polarity = []
res = [bool(re.match(r''+atom_type+"_\d+", category)) for atom_type in atom_map.keys()] res = [bool(re.match(r''+atom_type+"_\d+", category)) for atom_type in atom_map.keys()]
if True in res or category.startswith("dia") or category.startswith("box"): if True in res or category.startswith("dia") or category.startswith("box"):
category_to_polarity.append(polarity) category_to_polarity.append(not polarity)
else: else:
# dr = / # dr = /
if category.startswith("dr"): if category.startswith("dr"):
...@@ -106,7 +107,7 @@ def category_to_atoms_polarity(category, polarity): ...@@ -106,7 +107,7 @@ def category_to_atoms_polarity(category, polarity):
left_side, right_side = category_cut[0], category_cut[1] left_side, right_side = category_cut[0], category_cut[1]
# for the left side # for the left side
category_to_polarity += category_to_atoms_polarity(left_side, polarity) category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
# for the right side # for the right side
res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()] res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
...@@ -129,7 +130,7 @@ def category_to_atoms_polarity(category, polarity): ...@@ -129,7 +130,7 @@ def category_to_atoms_polarity(category, polarity):
category_to_polarity += category_to_atoms_polarity(left_side, not polarity) category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
# for the right side # for the right side
category_to_polarity += category_to_atoms_polarity(right_side, polarity) category_to_polarity += category_to_atoms_polarity(right_side, not polarity)
return category_to_polarity return category_to_polarity
...@@ -145,13 +146,14 @@ def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch): ...@@ -145,13 +146,14 @@ def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
list_batch = [] list_batch = []
for sentence in atoms_batch: for sentence in atoms_batch:
list_atoms = [] list_atoms = []
polarity = False
for category in sentence: for category in sentence:
polarity = True
for at in category_to_atoms_polarity(category, polarity): for at in category_to_atoms_polarity(category, polarity):
list_atoms.append(at) list_atoms.append(at)
list_batch.append(torch.as_tensor(list_atoms)) list_batch.append(torch.as_tensor(list_atoms))
return pad_sequence([list_batch[i] for i in range(len(list_batch))], return pad_sequence([list_batch[i] for i in range(len(list_batch))],
max_len=max_atoms_in_sentence, padding_value=0) max_len=max_atoms_in_sentence, padding_value=0)
atoms_pol = find_pos_neg_idexes(10, [['dr(1,np_1,s_1)', 'dl(1,np_1,s_1)']])
print(find_pos_neg_idexes(9, [['dr(0,dl(0,dr(0,pp_52,np_53),dl(0,np_41,np_32)),dr(0,s_54,dia(1,box(1,pp_55))))', 'dr(0,dl(0,np_58,s_59),pp_55)']])) print(atoms_pol)
\ No newline at end of file print(get_axiom_links(10, atoms_pol, [['dr(1,np_1,s_1)', 'dl(1,np_1,s_1)']]))
No preview for this file type
...@@ -14,7 +14,7 @@ class SinkhornLoss(Module): ...@@ -14,7 +14,7 @@ class SinkhornLoss(Module):
super(SinkhornLoss, self).__init__() super(SinkhornLoss, self).__init__()
def forward(self, predictions, truths): def forward(self, predictions, truths):
return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean') return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
for link, perm in zip(predictions, truths)) for link, perm in zip(predictions, truths))
...@@ -25,7 +25,9 @@ def mesure_accuracy(linking_plus_to_minus, axiom_links_pred): ...@@ -25,7 +25,9 @@ def mesure_accuracy(linking_plus_to_minus, axiom_links_pred):
""" """
correct_links = torch.ones(axiom_links_pred.size()) correct_links = torch.ones(axiom_links_pred.size())
correct_links[axiom_links_pred != linking_plus_to_minus] = 0 correct_links[axiom_links_pred != linking_plus_to_minus] = 0
correct_links[linking_plus_to_minus == -1] = 1
num_correct_links = correct_links.sum().item() num_correct_links = correct_links.sum().item()
num_masked_atoms = len(linking_plus_to_minus[linking_plus_to_minus == -1])
# diviser par nombre de links # diviser par nombre de links
return num_correct_links/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2]) return (num_correct_links - num_masked_atoms)/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2] - num_masked_atoms)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment