diff --git a/Configuration/config.ini b/Configuration/config.ini index fedb31786ee8e94c9842dda2a24209c43827183e..f1e5c0d836018139a417fe0f5d82bfb9fa5f7884 100644 --- a/Configuration/config.ini +++ b/Configuration/config.ini @@ -12,19 +12,19 @@ max_atoms_in_one_type=324 dim_encoder = 768 [MODEL_LINKER] -nhead=16 +nhead=8 dim_emb_atom = 256 -dim_feedforward_transformer = 512 +dim_feedforward_transformer = 768 num_layers=3 dim_cat_inter=512 dim_cat_out=256 dim_intermediate_FFN=128 -dim_pre_sinkhorn_transfo=64 +dim_pre_sinkhorn_transfo=32 dropout=0.1 sinkhorn_iters=5 [MODEL_TRAINING] -batch_size=16 +batch_size=32 epoch=30 seed_val=42 learning_rate=2e-3 \ No newline at end of file diff --git a/Linker/Linker.py b/Linker/Linker.py index 3012c7e1e58f9693d7556acbccae6611e85b2ec0..9d44b544e668b5fabdd79c890600f88c76116686 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -9,7 +9,7 @@ import time import torch import torch.nn.functional as F from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout, TransformerEncoderLayer, TransformerEncoder, \ - Embedding + Embedding, GELU from torch.optim import AdamW from torch.optim.lr_scheduler import StepLR from torch.utils.data import TensorDataset, random_split @@ -72,7 +72,8 @@ class Linker(Module): self.dim_feedforward_transformer = int(Configuration.modelLinkerConfig['dim_feedforward_transformer']) self.num_layers = int(Configuration.modelLinkerConfig['num_layers']) # torch cat - self.dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_out']) + dropout = float(Configuration.modelLinkerConfig['dropout']) + self.dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_inter']) self.dim_cat_out = int(Configuration.modelLinkerConfig['dim_cat_out']) dim_intermediate_FFN = int(Configuration.modelLinkerConfig['dim_intermediate_FFN']) dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo']) @@ -107,7 +108,9 @@ class Linker(Module): # Concatenation with word embedding dim_cat = dim_encoder + self.dim_emb_atom self.linker_encoder = Sequential( - FFN(dim_cat, self.dim_cat_inter, 0.1, d_out=self.dim_cat_out), + Linear(dim_cat, self.dim_cat_out), + GELU(), + Dropout(dropout), LayerNorm(self.dim_cat_out, eps=1e-8) ) @@ -142,11 +145,8 @@ class Linker(Module): sentences_batch = df_axiom_links["X"].str.strip().tolist() sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch) - atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"]) - atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids( - list(map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch))) - - num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence) + atoms_batch, atoms_polarity_batch, num_atoms_per_word = get_GOAL(self.max_len_sentence, self.max_atoms_in_sentence, df_axiom_links) + atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch) pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence) neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence) diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py index 8bb55d1673f82ce8863b8177ce537e26249d52cb..92b41d589d6d8f8b250f3f6882d1c2f92cce046f 100644 --- a/Linker/utils_linker.py +++ b/Linker/utils_linker.py @@ -1,4 +1,6 @@ import re + +import pandas as pd import regex import torch from torch.nn import Sequential, Linear, Dropout, GELU @@ -40,7 +42,6 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links): batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms """ atoms_batch = get_atoms_links_batch(batch_axiom_links) - atoms_batch = list(map(lambda sentence: sentence.split(" "), atoms_batch)) linking_plus_to_minus_all_types = [] for atom_type in list(atom_map_redux.keys()): # filtrer sur atom_batch que ce type puis filtrer avec les indices sur atom polarity @@ -76,12 +77,12 @@ def category_to_atoms_axiom_links(category, categories_to_atoms): word, cat = category.split(':') return category_to_atoms_axiom_links(cat, categories_to_atoms) elif True in res: - return " " + category + return [category] else: category_cut = regex.match(regex_categories_axiom_links, category).groups() category_cut = [cat for cat in category_cut if cat is not None] for cat in category_cut: - categories_to_atoms += category_to_atoms_axiom_links(cat, "") + categories_to_atoms += category_to_atoms_axiom_links(cat, []) return categories_to_atoms @@ -94,23 +95,22 @@ def get_atoms_links_batch(category_batch): """ batch = [] for sentence in category_batch: - categories_to_atoms = "" + categories_to_atoms = [] for category in sentence: if category != "let" and not category.startswith("GOAL:"): - categories_to_atoms += category_to_atoms_axiom_links(category, "") - categories_to_atoms += " [SEP]" - categories_to_atoms = categories_to_atoms.lstrip() + categories_to_atoms += category_to_atoms_axiom_links(category, []) + categories_to_atoms.append("[SEP]") elif category.startswith("GOAL:"): - categories_to_atoms += category_to_atoms_axiom_links(category, "") - categories_to_atoms = categories_to_atoms.lstrip() + categories_to_atoms = category_to_atoms_axiom_links(category, []) + categories_to_atoms batch.append(categories_to_atoms) return batch print("test to create links ", get_axiom_links(20, torch.stack([torch.as_tensor( - [False, True, False, False, False, True, False, True, False, False, True, False, False, False, True, False, - False, True, False, True, False, False, True, False, False, False, True])]), + [True, False, True, False, False, False, True, False, True, False, + False, True, False, False, False, True, False, False, True, False, + True, False, False, True, False, False, False, False, False, False])]), [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']])) @@ -132,12 +132,12 @@ def category_to_atoms(category, categories_to_atoms): word, cat = category.split(':') return category_to_atoms(cat, categories_to_atoms) elif True in res: - return " " + category + return [category] else: category_cut = regex.match(regex_categories, category).groups() category_cut = [cat for cat in category_cut if cat is not None] for cat in category_cut: - categories_to_atoms += category_to_atoms(cat, "") + categories_to_atoms += category_to_atoms(cat, []) return categories_to_atoms @@ -150,17 +150,17 @@ def get_atoms_batch(category_batch): """ batch = [] for sentence in category_batch: - categories_to_atoms = "" + categories_to_atoms = [] for category in sentence: if category != "let": - categories_to_atoms += category_to_atoms(category, "") - categories_to_atoms += " [SEP]" - categories_to_atoms = categories_to_atoms.lstrip() + categories_to_atoms += category_to_atoms(category, []) + categories_to_atoms.append("[SEP]") batch.append(categories_to_atoms) return batch -print(" test for get atoms in categories on ['dr(0,s,np)', 'let']", get_atoms_batch([["dr(0,s,np)", "let"]])) +print(" test for get atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']", + get_atoms_batch([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']])) # endregion @@ -201,7 +201,7 @@ def get_num_atoms_batch(category_batch, max_len_sentence): """ batch = [] for sentence in category_batch: - num_atoms_sentence = [] + num_atoms_sentence = [0] for category in sentence: num_atoms_in_word = category_to_num_atoms(category, 0) # add 1 because for word we have SEP at the end @@ -309,8 +309,7 @@ def find_pos_neg_idexes(atoms_batch): return list_batch -print( - " test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']", +print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np'] \n", find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']])) @@ -319,36 +318,32 @@ print( # region get atoms and polarities with GOAL -def get_GOAL(max_atoms_in_sentence, categories_batch): +def get_GOAL(max_len_sentence, max_atoms_in_sentence, df_axiom_links): + categories_batch = df_axiom_links["Z"] + categories_with_goal = df_axiom_links["Y"] polarities = find_pos_neg_idexes(categories_batch) atoms_batch = get_atoms_batch(categories_batch) - atoms_batch_for_polarities = list( - map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch)) + num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence) for s_idx in range(len(atoms_batch)): - for atom_type in list(atom_map_redux.keys()): - list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i] - and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))] - list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i] - and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))] - while len(list_minus) != len(list_plus): - if len(list_minus) > len(list_plus): - atoms_batch[s_idx] += " " + atom_type - atoms_batch_for_polarities[s_idx].append(atom_type) - polarities[s_idx].append(True) - else: - atoms_batch[s_idx] += " " + atom_type - atoms_batch_for_polarities[s_idx].append(atom_type) - polarities[s_idx].append(False) - list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i] - and atoms_batch_for_polarities[s_idx][i] == atom_type] - list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i] - and atoms_batch_for_polarities[s_idx][i] == atom_type] + goal = categories_with_goal[s_idx][-1] + polarities_goal = category_to_atoms_polarity(goal, True) + goal = re.search(r"(\w+)_\d+", goal).groups()[0] + atoms = category_to_atoms(goal, []) + + atoms_batch[s_idx] = atoms + atoms_batch[s_idx] # + ["[SEP]"] + polarities[s_idx] = polarities_goal + polarities[s_idx] # + False + num_atoms_batch[s_idx][0] += len(atoms) # +1 return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))], - max_len=max_atoms_in_sentence, padding_value=0) + max_len=max_atoms_in_sentence, padding_value=0), num_atoms_batch -print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)", "s"]])) +df_axiom_links = pd.DataFrame({"Z": [['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', + 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']], + "Y": [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', + 'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', + 'GOAL:np_7']]}) +print(" test for get GOAL ", get_GOAL(10, 30, df_axiom_links)) # endregion @@ -356,13 +351,11 @@ print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)", # region get idx for pos and neg def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence): - atoms_batch_for_polarities = list( - map(lambda sentence: sentence.split(" "), atoms_batch)) pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if - bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", - atoms_batch_for_polarities[s_idx][i])) and + bool( + re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and atoms_polarity_batch[s_idx][i]]) - for s_idx, sentence in enumerate(atoms_batch_for_polarities)], + for s_idx, sentence in enumerate(atoms_batch)], max_len=max_atoms_in_one_type // 2, padding_value=-1) for atom_type in list(atom_map_redux.keys())] @@ -370,24 +363,23 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_at def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence): - atoms_batch_for_polarities = list( - map(lambda sentence: sentence.split(" "), atoms_batch)) pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if - bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", - atoms_batch_for_polarities[s_idx][i])) and not - atoms_polarity_batch[s_idx][i]]) - for s_idx, sentence in enumerate(atoms_batch_for_polarities)], + bool( + re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and + not atoms_polarity_batch[s_idx][i]]) + for s_idx, sentence in enumerate(atoms_batch)], max_len=max_atoms_in_one_type // 2, padding_value=-1) for atom_type in list(atom_map_redux.keys())] return torch.stack(pos_idx).permute(1, 0, 2) -print(" test for cut into pos neg on ['s np [SEP] s [SEP] np s s n n']", get_neg_idx(['s np [SEP] s [SEP] np s s n n'], - torch.as_tensor( - [[False, True, False, False, - False, False, True, True, - False, True, - False, False]]), 10, 50)) +print(" test for cut into pos neg on ['dr(0,s,np)', 's']", + get_neg_idx([['s', 's', 'np', 's', 'np', '[SEP]', 's', '[SEP]']], + torch.as_tensor( + [[True, True, False, False, + True, False, False, False, + False, False, + False, False]]), 10, 50)) # endregion diff --git a/requirements.txt b/requirements.txt index c117e5384efe2b3cf6820f46d061d68059858966..eb0c0900dc80b14a240fb6e452b6a5a044d3e7b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,13 @@ numpy==1.22.2 -transformers==4.16.2 -torch==1.9.0 huggingface-hub==0.4.0 pandas==1.4.1 -sentencepiece -git+https://gitlab.irit.fr/pnria/global-helper/deepgrail-rnn/ \ No newline at end of file +Markdown==3.3.6 +packaging==21.3 +scikit-learn==1.0.2 +scipy==1.8.0 +sentencepiece==0.1.96 +tensorflow==2.9.1 +tensorboard==2.8.0 +torch==1.11.0 +tqdm==4.64.0 +transformers==4.19.0