From 4004d7c27c78870fd1f0cc754b7ee4b68e915bcd Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Tue, 28 Jun 2022 11:14:23 +0200
Subject: [PATCH] training

---
 Configuration/config.ini |  2 +-
 Linker/Linker.py         |  7 +---
 Linker/utils_linker.py   | 87 ++++++++++++++++++----------------------
 train.py                 |  2 +-
 4 files changed, 42 insertions(+), 56 deletions(-)

diff --git a/Configuration/config.ini b/Configuration/config.ini
index e0d94a3..73f1743 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -4,7 +4,7 @@ transformers = 4.16.2
 [DATASET_PARAMS]
 symbols_vocab_size=26
 atom_vocab_size=18
-max_len_sentence=83
+max_len_sentence=290
 max_atoms_in_sentence=875
 max_atoms_in_one_type=324
 
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 2d79231..99524a2 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -147,11 +147,8 @@ class Linker(Module):
         sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
         print(sentences_tokens)
 
-        atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
-        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(
-            list(map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch)))
-
-        num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
+        atoms_batch, atoms_polarity_batch, num_atoms_per_word = get_GOAL(self.max_len_sentence, self.max_atoms_in_sentence, df_axiom_links["Z"])
+        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
         pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
         neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index c7c7a8d..1c0ab28 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -40,7 +40,6 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
         batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
     atoms_batch = get_atoms_links_batch(batch_axiom_links)
-    atoms_batch = list(map(lambda sentence: sentence.split(" "), atoms_batch))
     linking_plus_to_minus_all_types = []
     for atom_type in list(atom_map_redux.keys()):
         # filtrer sur atom_batch que ce type puis filtrer avec les indices sur atom polarity
@@ -76,12 +75,12 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
         word, cat = category.split(':')
         return category_to_atoms_axiom_links(cat, categories_to_atoms)
     elif True in res:
-        return " " + category
+        return [category]
     else:
         category_cut = regex.match(regex_categories_axiom_links, category).groups()
         category_cut = [cat for cat in category_cut if cat is not None]
         for cat in category_cut:
-            categories_to_atoms += category_to_atoms_axiom_links(cat, "")
+            categories_to_atoms += category_to_atoms_axiom_links(cat, [])
         return categories_to_atoms
 
 
@@ -94,15 +93,13 @@ def get_atoms_links_batch(category_batch):
     """
     batch = []
     for sentence in category_batch:
-        categories_to_atoms = ""
+        categories_to_atoms = []
         for category in sentence:
             if category != "let" and not category.startswith("GOAL:"):
-                categories_to_atoms += category_to_atoms_axiom_links(category, "")
-                categories_to_atoms += " [SEP]"
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms += category_to_atoms_axiom_links(category, [])
+                categories_to_atoms.append("[SEP]")
             elif category.startswith("GOAL:"):
-                categories_to_atoms += category_to_atoms_axiom_links(category, "")
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms += category_to_atoms_axiom_links(category, [])
         batch.append(categories_to_atoms)
     return batch
 
@@ -132,12 +129,12 @@ def category_to_atoms(category, categories_to_atoms):
         word, cat = category.split(':')
         return category_to_atoms(cat, categories_to_atoms)
     elif True in res:
-        return " " + category
+        return [category]
     else:
         category_cut = regex.match(regex_categories, category).groups()
         category_cut = [cat for cat in category_cut if cat is not None]
         for cat in category_cut:
-            categories_to_atoms += category_to_atoms(cat, "")
+            categories_to_atoms += category_to_atoms(cat, [])
         return categories_to_atoms
 
 
@@ -150,17 +147,16 @@ def get_atoms_batch(category_batch):
     """
     batch = []
     for sentence in category_batch:
-        categories_to_atoms = ""
+        categories_to_atoms = []
         for category in sentence:
             if category != "let":
-                categories_to_atoms += category_to_atoms(category, "")
-                categories_to_atoms += " [SEP]"
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms += category_to_atoms(category, [])
+                categories_to_atoms.append("[SEP]")
         batch.append(categories_to_atoms)
     return batch
 
 
-print(" test for get atoms in categories on ['dr(0,s,np)', 'let']", get_atoms_batch([["dr(0,s,np)", "let"]]))
+print(" test for get atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']", get_atoms_batch([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']]))
 
 # endregion
@@ -319,36 +315,35 @@ print(
 
 # region get atoms and polarities with GOAL
 
 
-def get_GOAL(max_atoms_in_sentence, categories_batch):
+def get_GOAL(max_len_sentence, max_atoms_in_sentence, categories_batch):
     polarities = find_pos_neg_idexes(categories_batch)
     atoms_batch = get_atoms_batch(categories_batch)
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch))
+    num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence)
     for s_idx in range(len(atoms_batch)):
         for atom_type in list(atom_map_redux.keys()):
-            list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
-                         and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
-            list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
-                          and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
+            list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
+                         and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))]
+            list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
+                          and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))]
             while len(list_minus) != len(list_plus):
                 if len(list_minus) > len(list_plus):
-                    atoms_batch[s_idx] += " " + atom_type
-                    atoms_batch_for_polarities[s_idx].append(atom_type)
-                    polarities[s_idx].append(True)
+                    atoms_batch[s_idx].insert(0, atom_type)
+                    polarities[s_idx].insert(0, True)
+                    num_atoms_batch[s_idx][0] += 1
                 else:
-                    atoms_batch[s_idx] += " " + atom_type
-                    atoms_batch_for_polarities[s_idx].append(atom_type)
-                    polarities[s_idx].append(False)
-                list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
-                             and atoms_batch_for_polarities[s_idx][i] == atom_type]
-                list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
-                              and atoms_batch_for_polarities[s_idx][i] == atom_type]
+                    atoms_batch[s_idx].insert(0, atom_type)
+                    polarities[s_idx].insert(0, False)
+                    num_atoms_batch[s_idx][0] += 1
+                list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
+                             and atoms_batch[s_idx][i] == atom_type]
+                list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
+                              and atoms_batch[s_idx][i] == atom_type]
 
     return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-                                     max_len=max_atoms_in_sentence, padding_value=0)
+                                     max_len=max_atoms_in_sentence, padding_value=0), num_atoms_batch
 
 
-print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)", "s"]]))
+print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(5, 12, [["dr(0,s,np)", "s"]]))
 
 # endregion
@@ -356,13 +351,10 @@ print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)",
 # region get idx for pos and neg
 
 
 def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: sentence.split(" "), atoms_batch))
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
-                                                            atoms_batch_for_polarities[s_idx][i])) and
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
                                               atoms_polarity_batch[s_idx][i]])
-                             for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
+                             for s_idx, sentence in enumerate(atoms_batch)],
                              max_len=max_atoms_in_one_type // 2, padding_value=-1)
                for atom_type in list(atom_map_redux.keys())]
@@ -370,24 +362,21 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_at
 def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: sentence.split(" "), atoms_batch))
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
-                                                            atoms_batch_for_polarities[s_idx][i])) and not
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",atoms_batch[s_idx][i])) and not
                                               atoms_polarity_batch[s_idx][i]])
-                             for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
+                             for s_idx, sentence in enumerate(atoms_batch)],
                              max_len=max_atoms_in_one_type // 2, padding_value=-1)
               for atom_type in list(atom_map_redux.keys())]
 
     return torch.stack(pos_idx).permute(1, 0, 2)
 
 
-print(" test for cut into pos neg on ['s np [SEP] s [SEP] np s s n n']", get_neg_idx(['s np [SEP] s [SEP] np s s n n'],
+print(" test for cut into pos neg on ['dr(0,s,np)', 's']", get_neg_idx([['s', 's', 'np', 's', 'np', '[SEP]', 's', '[SEP]']],
                                                                         torch.as_tensor(
-                                                                            [[False, True, False, False,
-                                                                              False, False, True, True,
-                                                                              False, True,
+                                                                            [[True, True, False, False,
+                                                                              True, False, False, False,
+                                                                              False, False,
                                                                               False, False]]), 10, 50))
 
 # endregion
diff --git a/train.py b/train.py
index 4721ed9..fdf3936 100644
--- a/train.py
+++ b/train.py
@@ -5,7 +5,7 @@ from utils import read_csv_pgbar
 torch.cuda.empty_cache()
 
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 4
+nb_sentences = batch_size * 800
 epochs = int(Configuration.modelTrainingConfig['epoch'])
 
 file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
-- 
GitLab
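
Review note: the central change in this patch is that atom sequences flow through the pipeline as Python lists instead of space-joined strings, so callers index atoms directly rather than re-splitting on " ". The following standalone sketch (illustrative only, not the repo's code: the real category_to_atoms matches against the regex_categories pattern, while this version uses a naive depth-counting split and simply drops each connective's mode index) shows the shape of what the list-based get_atoms_batch now returns, using the same input as the patch's updated test print.

def split_args(body):
    # Split "0,dl(0,np,np),np" on top-level commas only.
    parts, depth, cur = [], 0, []
    for ch in body:
        if ch == "," and depth == 0:
            parts.append("".join(cur))
            cur = []
        else:
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
            cur.append(ch)
    parts.append("".join(cur))
    return parts

def category_to_atoms(category):
    # Atomic categories become one-element lists; composite categories
    # (dr, dl, ...) are flattened recursively, left to right.
    if "(" not in category:
        return [category]
    _, body = category.split("(", 1)
    atoms = []
    for arg in split_args(body[:-1])[1:]:  # skip the mode index, e.g. '0'
        atoms += category_to_atoms(arg)
    return atoms

def get_atoms_batch(category_batch):
    batch = []
    for sentence in category_batch:
        atoms = []
        for category in sentence:
            if category != "let":
                atoms += category_to_atoms(category)
                atoms.append("[SEP]")  # list append replaces the old ' [SEP]' string concat
        batch.append(atoms)
    return batch

print(get_atoms_batch([["dr(0,np,n)", "n", "dr(0,dl(0,np,np),np)", "let"]]))
# [['np', 'n', '[SEP]', 'n', '[SEP]', 'np', 'np', 'np', '[SEP]']]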
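In the same spirit, here is a standalone sketch of the new balancing step in get_GOAL: when the positive and negative occurrences of an atom type do not pair up, the patch now inserts the missing atoms at the front of the sentence (rather than appending to a string) and bumps num_atoms_batch[s_idx][0] so the padding is attributed to the first word. Assumptions are flagged in the comments: the type list stands in for atom_map_redux.keys(), and exact string equality stands in for the regex that also matches suffixed atoms such as np_s.

ATOM_TYPES = ["cl_r", "pp", "n", "np", "s", "txt"]  # stand-in for atom_map_redux.keys()

def balance_goal(atoms, polarities, num_atoms_per_word):
    # For each atom type, prepend atoms of the under-represented polarity
    # until positives and negatives pair up 1:1, mirroring the patch's
    # insert(0, ...) calls and num_atoms_batch[s_idx][0] += 1 updates.
    for atom_type in ATOM_TYPES:
        n_plus = sum(1 for a, p in zip(atoms, polarities) if p and a == atom_type)
        n_minus = sum(1 for a, p in zip(atoms, polarities) if not p and a == atom_type)
        while n_plus != n_minus:
            atoms.insert(0, atom_type)
            polarities.insert(0, n_plus < n_minus)  # add the missing polarity
            num_atoms_per_word[0] += 1              # padding counts toward word 0
            if n_plus < n_minus:
                n_plus += 1
            else:
                n_minus += 1
    return atoms, polarities, num_atoms_per_word

# 'np' occurs once negatively and never positively, so one positive 'np'
# is prepended; 's' already has one positive and one negative occurrence.
print(balance_goal(["s", "np", "[SEP]", "s"], [True, False, False, False], [2, 2]))
# (['np', 's', 'np', '[SEP]', 's'], [True, True, False, False, False], [3, 2])

Prepending (rather than the old append) keeps the padding atoms clear of the real word-aligned positions at the tail of the sentence, which is presumably why the first entry of num_atoms_per_word absorbs the count.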