Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision

Target

Select target project
  • pnria/global-helper/deepgrail-linker
1 result
Select Git revision
Show changes
Commits on Source (2)
......@@ -4,7 +4,7 @@ transformers = 4.16.2
[DATASET_PARAMS]
symbols_vocab_size=26
atom_vocab_size=18
max_len_sentence=83
max_len_sentence=290
max_atoms_in_sentence=875
max_atoms_in_one_type=324
......
......@@ -147,11 +147,8 @@ class Linker(Module):
sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
print(sentences_tokens)
atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(
list(map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch)))
num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
atoms_batch, atoms_polarity_batch, num_atoms_per_word = get_GOAL(self.max_len_sentence, self.max_atoms_in_sentence, df_axiom_links["Z"])
atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
......
......@@ -40,7 +40,6 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
"""
atoms_batch = get_atoms_links_batch(batch_axiom_links)
atoms_batch = list(map(lambda sentence: sentence.split(" "), atoms_batch))
linking_plus_to_minus_all_types = []
for atom_type in list(atom_map_redux.keys()):
# filtrer sur atom_batch que ce type puis filtrer avec les indices sur atom polarity
......@@ -76,12 +75,12 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
word, cat = category.split(':')
return category_to_atoms_axiom_links(cat, categories_to_atoms)
elif True in res:
return " " + category
return [category]
else:
category_cut = regex.match(regex_categories_axiom_links, category).groups()
category_cut = [cat for cat in category_cut if cat is not None]
for cat in category_cut:
categories_to_atoms += category_to_atoms_axiom_links(cat, "")
categories_to_atoms += category_to_atoms_axiom_links(cat, [])
return categories_to_atoms
......@@ -94,15 +93,13 @@ def get_atoms_links_batch(category_batch):
"""
batch = []
for sentence in category_batch:
categories_to_atoms = ""
categories_to_atoms = []
for category in sentence:
if category != "let" and not category.startswith("GOAL:"):
categories_to_atoms += category_to_atoms_axiom_links(category, "")
categories_to_atoms += " [SEP]"
categories_to_atoms = categories_to_atoms.lstrip()
categories_to_atoms += category_to_atoms_axiom_links(category, [])
categories_to_atoms.append("[SEP]")
elif category.startswith("GOAL:"):
categories_to_atoms += category_to_atoms_axiom_links(category, "")
categories_to_atoms = categories_to_atoms.lstrip()
categories_to_atoms += category_to_atoms_axiom_links(category, [])
batch.append(categories_to_atoms)
return batch
......@@ -132,12 +129,12 @@ def category_to_atoms(category, categories_to_atoms):
word, cat = category.split(':')
return category_to_atoms(cat, categories_to_atoms)
elif True in res:
return " " + category
return [category]
else:
category_cut = regex.match(regex_categories, category).groups()
category_cut = [cat for cat in category_cut if cat is not None]
for cat in category_cut:
categories_to_atoms += category_to_atoms(cat, "")
categories_to_atoms += category_to_atoms(cat, [])
return categories_to_atoms
......@@ -150,17 +147,16 @@ def get_atoms_batch(category_batch):
"""
batch = []
for sentence in category_batch:
categories_to_atoms = ""
categories_to_atoms = []
for category in sentence:
if category != "let":
categories_to_atoms += category_to_atoms(category, "")
categories_to_atoms += " [SEP]"
categories_to_atoms = categories_to_atoms.lstrip()
categories_to_atoms += category_to_atoms(category, [])
categories_to_atoms.append("[SEP]")
batch.append(categories_to_atoms)
return batch
print(" test for get atoms in categories on ['dr(0,s,np)', 'let']", get_atoms_batch([["dr(0,s,np)", "let"]]))
print(" test for get atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']", get_atoms_batch([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']]))
# endregion
......@@ -319,36 +315,35 @@ print(
# region get atoms and polarities with GOAL
def get_GOAL(max_atoms_in_sentence, categories_batch):
def get_GOAL(max_len_sentence, max_atoms_in_sentence, categories_batch):
polarities = find_pos_neg_idexes(categories_batch)
atoms_batch = get_atoms_batch(categories_batch)
atoms_batch_for_polarities = list(
map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch))
num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence)
for s_idx in range(len(atoms_batch)):
for atom_type in list(atom_map_redux.keys()):
list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))]
list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))]
while len(list_minus) != len(list_plus):
if len(list_minus) > len(list_plus):
atoms_batch[s_idx] += " " + atom_type
atoms_batch_for_polarities[s_idx].append(atom_type)
polarities[s_idx].append(True)
atoms_batch[s_idx].insert(0, atom_type)
polarities[s_idx].insert(0, True)
num_atoms_batch[s_idx][0] += 1
else:
atoms_batch[s_idx] += " " + atom_type
atoms_batch_for_polarities[s_idx].append(atom_type)
polarities[s_idx].append(False)
list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
and atoms_batch_for_polarities[s_idx][i] == atom_type]
list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
and atoms_batch_for_polarities[s_idx][i] == atom_type]
atoms_batch[s_idx].insert(0, atom_type)
polarities[s_idx].insert(0, False)
num_atoms_batch[s_idx][0] += 1
list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
and atoms_batch[s_idx][i] == atom_type]
list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
and atoms_batch[s_idx][i] == atom_type]
return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
max_len=max_atoms_in_sentence, padding_value=0)
max_len=max_atoms_in_sentence, padding_value=0), num_atoms_batch
print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)", "s"]]))
print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(5, 12, [["dr(0,s,np)", "s"]]))
# endregion
......@@ -356,13 +351,10 @@ print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)",
# region get idx for pos and neg
def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
atoms_batch_for_polarities = list(
map(lambda sentence: sentence.split(" "), atoms_batch))
pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
atoms_batch_for_polarities[s_idx][i])) and
bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
atoms_polarity_batch[s_idx][i]])
for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
for s_idx, sentence in enumerate(atoms_batch)],
max_len=max_atoms_in_one_type // 2, padding_value=-1)
for atom_type in list(atom_map_redux.keys())]
......@@ -370,24 +362,21 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_at
def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
atoms_batch_for_polarities = list(
map(lambda sentence: sentence.split(" "), atoms_batch))
pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
atoms_batch_for_polarities[s_idx][i])) and not
bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",atoms_batch[s_idx][i])) and not
atoms_polarity_batch[s_idx][i]])
for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
for s_idx, sentence in enumerate(atoms_batch)],
max_len=max_atoms_in_one_type // 2, padding_value=-1)
for atom_type in list(atom_map_redux.keys())]
return torch.stack(pos_idx).permute(1, 0, 2)
print(" test for cut into pos neg on ['s np [SEP] s [SEP] np s s n n']", get_neg_idx(['s np [SEP] s [SEP] np s s n n'],
print(" test for cut into pos neg on ['dr(0,s,np)', 's']", get_neg_idx([['s', 's', 'np', 's', 'np', '[SEP]', 's', '[SEP]']],
torch.as_tensor(
[[False, True, False, False,
False, False, True, True,
False, True,
[[True, True, False, False,
True, False, False, False,
False, False,
False, False]]), 10, 50))
# endregion
......@@ -5,7 +5,7 @@ from utils import read_csv_pgbar
torch.cuda.empty_cache()
batch_size = int(Configuration.modelTrainingConfig['batch_size'])
nb_sentences = batch_size * 4
nb_sentences = batch_size * 800
epochs = int(Configuration.modelTrainingConfig['epoch'])
file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
......