Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • main
1 result

Target

Select target project
  • pnria/global-helper/deepgrail-linker
1 result
Select Git revision
  • main
1 result
Show changes
Commits on Source (2)
@@ -4,7 +4,7 @@ transformers = 4.16.2
 [DATASET_PARAMS]
 symbols_vocab_size=26
 atom_vocab_size=18
-max_len_sentence=83
+max_len_sentence=290
 max_atoms_in_sentence=875
 max_atoms_in_one_type=324
...
@@ -147,11 +147,8 @@ class Linker(Module):
         sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
         print(sentences_tokens)
-        atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
-        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(
-            list(map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch)))
-        num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
+        atoms_batch, atoms_polarity_batch, num_atoms_per_word = get_GOAL(self.max_len_sentence, self.max_atoms_in_sentence, df_axiom_links["Z"])
+        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
         pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
         neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
...
@@ -40,7 +40,6 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
     batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
     atoms_batch = get_atoms_links_batch(batch_axiom_links)
-    atoms_batch = list(map(lambda sentence: sentence.split(" "), atoms_batch))
     linking_plus_to_minus_all_types = []
     for atom_type in list(atom_map_redux.keys()):
         # filtrer sur atom_batch que ce type puis filtrer avec les indices sur atom polarity
@@ -76,12 +75,12 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
         word, cat = category.split(':')
         return category_to_atoms_axiom_links(cat, categories_to_atoms)
     elif True in res:
-        return " " + category
+        return [category]
     else:
         category_cut = regex.match(regex_categories_axiom_links, category).groups()
         category_cut = [cat for cat in category_cut if cat is not None]
         for cat in category_cut:
-            categories_to_atoms += category_to_atoms_axiom_links(cat, "")
+            categories_to_atoms += category_to_atoms_axiom_links(cat, [])
         return categories_to_atoms
@@ -94,15 +93,13 @@ def get_atoms_links_batch(category_batch):
     """
     batch = []
     for sentence in category_batch:
-        categories_to_atoms = ""
+        categories_to_atoms = []
         for category in sentence:
             if category != "let" and not category.startswith("GOAL:"):
-                categories_to_atoms += category_to_atoms_axiom_links(category, "")
-                categories_to_atoms += " [SEP]"
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms += category_to_atoms_axiom_links(category, [])
+                categories_to_atoms.append("[SEP]")
             elif category.startswith("GOAL:"):
-                categories_to_atoms += category_to_atoms_axiom_links(category, "")
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms += category_to_atoms_axiom_links(category, [])
         batch.append(categories_to_atoms)
     return batch
@@ -132,12 +129,12 @@ def category_to_atoms(category, categories_to_atoms):
         word, cat = category.split(':')
         return category_to_atoms(cat, categories_to_atoms)
     elif True in res:
-        return " " + category
+        return [category]
     else:
         category_cut = regex.match(regex_categories, category).groups()
         category_cut = [cat for cat in category_cut if cat is not None]
         for cat in category_cut:
-            categories_to_atoms += category_to_atoms(cat, "")
+            categories_to_atoms += category_to_atoms(cat, [])
         return categories_to_atoms
@@ -150,17 +147,16 @@ def get_atoms_batch(category_batch):
     """
     batch = []
     for sentence in category_batch:
-        categories_to_atoms = ""
+        categories_to_atoms = []
         for category in sentence:
             if category != "let":
-                categories_to_atoms += category_to_atoms(category, "")
-                categories_to_atoms += " [SEP]"
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms += category_to_atoms(category, [])
+                categories_to_atoms.append("[SEP]")
         batch.append(categories_to_atoms)
     return batch


-print(" test for get atoms in categories on ['dr(0,s,np)', 'let']", get_atoms_batch([["dr(0,s,np)", "let"]]))
+print(" test for get atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']", get_atoms_batch([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']]))

 # endregion
@@ -319,36 +315,35 @@ print(

 # region get atoms and polarities with GOAL
-def get_GOAL(max_atoms_in_sentence, categories_batch):
+def get_GOAL(max_len_sentence, max_atoms_in_sentence, categories_batch):
     polarities = find_pos_neg_idexes(categories_batch)
     atoms_batch = get_atoms_batch(categories_batch)
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch))
+    num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence)
     for s_idx in range(len(atoms_batch)):
         for atom_type in list(atom_map_redux.keys()):
-            list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
-                         and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
-            list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
-                          and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
+            list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
+                         and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))]
+            list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
+                          and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))]
             while len(list_minus) != len(list_plus):
                 if len(list_minus) > len(list_plus):
-                    atoms_batch[s_idx] += " " + atom_type
-                    atoms_batch_for_polarities[s_idx].append(atom_type)
-                    polarities[s_idx].append(True)
+                    atoms_batch[s_idx].insert(0, atom_type)
+                    polarities[s_idx].insert(0, True)
+                    num_atoms_batch[s_idx][0] += 1
                 else:
-                    atoms_batch[s_idx] += " " + atom_type
-                    atoms_batch_for_polarities[s_idx].append(atom_type)
-                    polarities[s_idx].append(False)
+                    atoms_batch[s_idx].insert(0, atom_type)
+                    polarities[s_idx].insert(0, False)
+                    num_atoms_batch[s_idx][0] += 1
-                list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
-                             and atoms_batch_for_polarities[s_idx][i] == atom_type]
-                list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
-                              and atoms_batch_for_polarities[s_idx][i] == atom_type]
+                list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
+                             and atoms_batch[s_idx][i] == atom_type]
+                list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
+                              and atoms_batch[s_idx][i] == atom_type]
     return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-                                     max_len=max_atoms_in_sentence, padding_value=0)
+                                     max_len=max_atoms_in_sentence, padding_value=0), num_atoms_batch


-print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)", "s"]]))
+print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(5, 12, [["dr(0,s,np)", "s"]]))

 # endregion
@@ -356,13 +351,10 @@ print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)",

 # region get idx for pos and neg
 def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: sentence.split(" "), atoms_batch))
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
-                                                            atoms_batch_for_polarities[s_idx][i])) and
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
                                               atoms_polarity_batch[s_idx][i]])
-               for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
+               for s_idx, sentence in enumerate(atoms_batch)],
                max_len=max_atoms_in_one_type // 2, padding_value=-1)
               for atom_type in list(atom_map_redux.keys())]
@@ -370,24 +362,21 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_at

 def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: sentence.split(" "), atoms_batch))
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
-                                                            atoms_batch_for_polarities[s_idx][i])) and not
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",atoms_batch[s_idx][i])) and not
                                               atoms_polarity_batch[s_idx][i]])
-               for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
+               for s_idx, sentence in enumerate(atoms_batch)],
                max_len=max_atoms_in_one_type // 2, padding_value=-1)
               for atom_type in list(atom_map_redux.keys())]

     return torch.stack(pos_idx).permute(1, 0, 2)


-print(" test for cut into pos neg on ['s np [SEP] s [SEP] np s s n n']", get_neg_idx(['s np [SEP] s [SEP] np s s n n'],
-                                                                                     torch.as_tensor(
-                                                                                         [[False, True, False, False,
-                                                                                           False, False, True, True,
-                                                                                           False, True, False, False,
-                                                                                           False, False]]), 10, 50))
+print(" test for cut into pos neg on ['dr(0,s,np)', 's']", get_neg_idx([['s', 's', 'np', 's', 'np', '[SEP]', 's', '[SEP]']],
+                                                                       torch.as_tensor(
+                                                                           [[True, True, False, False,
+                                                                             True, False, False, False,
+                                                                             False, False]]), 10, 50))

 # endregion
@@ -5,7 +5,7 @@ from utils import read_csv_pgbar

 torch.cuda.empty_cache()

 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 4
+nb_sentences = batch_size * 800
 epochs = int(Configuration.modelTrainingConfig['epoch'])
 file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
...