Skip to content
Snippets Groups Projects
Commit a8a529b4 authored by Caroline de Pourtalès's avatar Caroline de Pourtalès :speech_balloon:
Browse files

Merge branch 'work-on-goal' into 'word-cat-atom-embedding'

training

See merge request !4
parents 82051bbb 4004d7c2
No related branches found
No related tags found
1 merge request: !4 training
......@@ -4,7 +4,7 @@ transformers = 4.16.2
[DATASET_PARAMS]
symbols_vocab_size=26
atom_vocab_size=18
max_len_sentence=83
max_len_sentence=290
max_atoms_in_sentence=875
max_atoms_in_one_type=324
......
......@@ -147,11 +147,8 @@ class Linker(Module):
sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
print(sentences_tokens)
atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(
list(map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch)))
num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
atoms_batch, atoms_polarity_batch, num_atoms_per_word = get_GOAL(self.max_len_sentence, self.max_atoms_in_sentence, df_axiom_links["Z"])
atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
......
......@@ -40,7 +40,6 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
"""
atoms_batch = get_atoms_links_batch(batch_axiom_links)
atoms_batch = list(map(lambda sentence: sentence.split(" "), atoms_batch))
linking_plus_to_minus_all_types = []
for atom_type in list(atom_map_redux.keys()):
# filtrer sur atom_batch que ce type puis filtrer avec les indices sur atom polarity
......@@ -76,12 +75,12 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
word, cat = category.split(':')
return category_to_atoms_axiom_links(cat, categories_to_atoms)
elif True in res:
return " " + category
return [category]
else:
category_cut = regex.match(regex_categories_axiom_links, category).groups()
category_cut = [cat for cat in category_cut if cat is not None]
for cat in category_cut:
categories_to_atoms += category_to_atoms_axiom_links(cat, "")
categories_to_atoms += category_to_atoms_axiom_links(cat, [])
return categories_to_atoms
......@@ -94,15 +93,13 @@ def get_atoms_links_batch(category_batch):
"""
batch = []
for sentence in category_batch:
categories_to_atoms = ""
categories_to_atoms = []
for category in sentence:
if category != "let" and not category.startswith("GOAL:"):
categories_to_atoms += category_to_atoms_axiom_links(category, "")
categories_to_atoms += " [SEP]"
categories_to_atoms = categories_to_atoms.lstrip()
categories_to_atoms += category_to_atoms_axiom_links(category, [])
categories_to_atoms.append("[SEP]")
elif category.startswith("GOAL:"):
categories_to_atoms += category_to_atoms_axiom_links(category, "")
categories_to_atoms = categories_to_atoms.lstrip()
categories_to_atoms += category_to_atoms_axiom_links(category, [])
batch.append(categories_to_atoms)
return batch
......@@ -132,12 +129,12 @@ def category_to_atoms(category, categories_to_atoms):
word, cat = category.split(':')
return category_to_atoms(cat, categories_to_atoms)
elif True in res:
return " " + category
return [category]
else:
category_cut = regex.match(regex_categories, category).groups()
category_cut = [cat for cat in category_cut if cat is not None]
for cat in category_cut:
categories_to_atoms += category_to_atoms(cat, "")
categories_to_atoms += category_to_atoms(cat, [])
return categories_to_atoms
......@@ -150,17 +147,16 @@ def get_atoms_batch(category_batch):
"""
batch = []
for sentence in category_batch:
categories_to_atoms = ""
categories_to_atoms = []
for category in sentence:
if category != "let":
categories_to_atoms += category_to_atoms(category, "")
categories_to_atoms += " [SEP]"
categories_to_atoms = categories_to_atoms.lstrip()
categories_to_atoms += category_to_atoms(category, [])
categories_to_atoms.append("[SEP]")
batch.append(categories_to_atoms)
return batch
print(" test for get atoms in categories on ['dr(0,s,np)', 'let']", get_atoms_batch([["dr(0,s,np)", "let"]]))
print(" test for get atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']", get_atoms_batch([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']]))
# endregion
......@@ -319,36 +315,35 @@ print(
# region get atoms and polarities with GOAL
def get_GOAL(max_atoms_in_sentence, categories_batch):
def get_GOAL(max_len_sentence, max_atoms_in_sentence, categories_batch):
polarities = find_pos_neg_idexes(categories_batch)
atoms_batch = get_atoms_batch(categories_batch)
atoms_batch_for_polarities = list(
map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch))
num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence)
for s_idx in range(len(atoms_batch)):
for atom_type in list(atom_map_redux.keys()):
list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))]
list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))]
while len(list_minus) != len(list_plus):
if len(list_minus) > len(list_plus):
atoms_batch[s_idx] += " " + atom_type
atoms_batch_for_polarities[s_idx].append(atom_type)
polarities[s_idx].append(True)
atoms_batch[s_idx].insert(0, atom_type)
polarities[s_idx].insert(0, True)
num_atoms_batch[s_idx][0] += 1
else:
atoms_batch[s_idx] += " " + atom_type
atoms_batch_for_polarities[s_idx].append(atom_type)
polarities[s_idx].append(False)
list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
and atoms_batch_for_polarities[s_idx][i] == atom_type]
list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
and atoms_batch_for_polarities[s_idx][i] == atom_type]
atoms_batch[s_idx].insert(0, atom_type)
polarities[s_idx].insert(0, False)
num_atoms_batch[s_idx][0] += 1
list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
and atoms_batch[s_idx][i] == atom_type]
list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
and atoms_batch[s_idx][i] == atom_type]
return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
max_len=max_atoms_in_sentence, padding_value=0)
max_len=max_atoms_in_sentence, padding_value=0), num_atoms_batch
print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)", "s"]]))
print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(5, 12, [["dr(0,s,np)", "s"]]))
# endregion
......@@ -356,13 +351,10 @@ print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)",
# region get idx for pos and neg
def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
atoms_batch_for_polarities = list(
map(lambda sentence: sentence.split(" "), atoms_batch))
pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
atoms_batch_for_polarities[s_idx][i])) and
bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
atoms_polarity_batch[s_idx][i]])
for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
for s_idx, sentence in enumerate(atoms_batch)],
max_len=max_atoms_in_one_type // 2, padding_value=-1)
for atom_type in list(atom_map_redux.keys())]
......@@ -370,24 +362,21 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_at
def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
atoms_batch_for_polarities = list(
map(lambda sentence: sentence.split(" "), atoms_batch))
pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
atoms_batch_for_polarities[s_idx][i])) and not
bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",atoms_batch[s_idx][i])) and not
atoms_polarity_batch[s_idx][i]])
for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
for s_idx, sentence in enumerate(atoms_batch)],
max_len=max_atoms_in_one_type // 2, padding_value=-1)
for atom_type in list(atom_map_redux.keys())]
return torch.stack(pos_idx).permute(1, 0, 2)
print(" test for cut into pos neg on ['s np [SEP] s [SEP] np s s n n']", get_neg_idx(['s np [SEP] s [SEP] np s s n n'],
print(" test for cut into pos neg on ['dr(0,s,np)', 's']", get_neg_idx([['s', 's', 'np', 's', 'np', '[SEP]', 's', '[SEP]']],
torch.as_tensor(
[[False, True, False, False,
False, False, True, True,
False, True,
[[True, True, False, False,
True, False, False, False,
False, False,
False, False]]), 10, 50))
# endregion
......@@ -5,7 +5,7 @@ from utils import read_csv_pgbar
torch.cuda.empty_cache()
batch_size = int(Configuration.modelTrainingConfig['batch_size'])
nb_sentences = batch_size * 4
nb_sentences = batch_size * 800
epochs = int(Configuration.modelTrainingConfig['epoch'])
file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment