diff --git a/Configuration/config.ini b/Configuration/config.ini
index 2a7495fed35267f6d1582ca71c192dd5640bf7af..6373a249e15389034550d0a0ac75bf8097d1c31e 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -3,7 +3,7 @@ transformers = 4.16.2
 
 [DATASET_PARAMS]
 symbols_vocab_size=26
-atom_vocab_size=20
+atom_vocab_size=17
 max_len_sentence=109
 max_atoms_in_sentence=1250
 max_atoms_in_one_type=324
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 6d14226412d5ca9e58b38ac19f11b7dd8d865597..e7f2bac722d66fdb226d1c86c4daa45057b8dad7 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -70,8 +70,10 @@ class Linker(Module):
         self.Supertagger = supertagger
 
         self.atom_map = atom_map
+        self.sub_atoms_type_list = ['cl_r', 'pp', 'n', 'np', 'cl_y', 'txt', 's']
         self.padding_id = self.atom_map['[PAD]']
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
+        self.inverse_map = self.atoms_tokenizer.inverse_atom_map
         self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, atom_vocab_size, self.padding_id)
 
         self.linker_encoder = AttentionDecoderLayer()
@@ -93,7 +95,6 @@ class Linker(Module):
 
         self.to(self.device)
 
-
     def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
         r"""
         Args:
@@ -112,7 +113,7 @@ class Linker(Module):
 
         atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["Z"])
 
-        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
+        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, self.sub_atoms_type_list, atoms_polarity_batch,
                                             df_axiom_links["Y"])
         truth_links_batch = truth_links_batch.permute(1, 0, 2)
 
@@ -158,16 +159,16 @@ class Linker(Module):
                                                self.make_decoder_mask(atoms_batch_tokenized))
 
             link_weights = []
-            for atom_type in list(self.atom_map.keys())[:-1]:
+            for atom_type in self.sub_atoms_type_list:
                 pos_encoding = pad_sequence(
                     [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
-                                                atoms_polarity_batch, atom_type, s_idx)
+                                                atoms_polarity_batch, atom_type, self.inverse_map, s_idx)
                      for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
                     max_len=self.max_atoms_in_one_type // 2)
 
                 neg_encoding = pad_sequence(
                     [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
-                                                atoms_polarity_batch, atom_type, s_idx)
+                                                atoms_polarity_batch, atom_type, self.inverse_map, s_idx)
                      for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
                     max_len=self.max_atoms_in_one_type // 2)
 
@@ -298,7 +299,7 @@ class Linker(Module):
         """
         self.eval()
         with torch.no_grad():
-            # get atoms
+            # get atoms
             atoms_batch = get_atoms_batch(categories)
             atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
@@ -313,16 +314,16 @@ class Linker(Module):
                                                    self.make_decoder_mask(atoms_tokenized))
 
             link_weights = []
-            for atom_type in list(self.atom_map.keys())[:-1]:
+            for atom_type in self.sub_atoms_type_list:
                 pos_encoding = pad_sequence(
                     [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                                                polarities, atom_type, s_idx)
+                                                polarities, atom_type, self.inverse_map, s_idx)
                      for s_idx in range(len(polarities))], padding_value=0,
                     max_len=self.max_atoms_in_one_type // 2)
 
                 neg_encoding = pad_sequence(
                     [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                                                polarities, atom_type, s_idx)
+                                                polarities, atom_type, self.inverse_map, s_idx)
                      for s_idx in range(len(polarities))], padding_value=0,
                     max_len=self.max_atoms_in_one_type // 2)
 
diff --git a/Linker/atom_map.py b/Linker/atom_map.py
index 301bf71cee949fc2f2c94eade8f2bf5b326f2d1c..d45c4b9709ee161960302e485f01a39e08c3fc76 100644
--- a/Linker/atom_map.py
+++ b/Linker/atom_map.py
@@ -1,22 +1,19 @@
 atom_map = \
     {'cl_r': 0,
-     "pp":1,
+     "pp": 1,
      'n': 2,
-     'p': 3,
-     's_ppres': 4,
-     'dia': 5,
-     's_whq': 6,
-     's_q': 7,
-     'np': 8,
-     's_inf': 9,
-     's_pass': 10,
-     'pp_a': 11,
-     'pp_par': 12,
-     'pp_de': 13,
-     'cl_y': 14,
-     'box': 15,
-     'txt': 16,
-     's': 17,
-     's_ppart': 18,
-     '[PAD]': 19
+     's_ppres': 3,
+     's_whq': 4,
+     's_q': 5,
+     'np': 6,
+     's_inf': 7,
+     's_pass': 8,
+     'pp_a': 9,
+     'pp_par': 10,
+     'pp_de': 11,
+     'cl_y': 12,
+     'txt': 13,
+     's': 14,
+     's_ppart': 15,
+     '[PAD]': 16
     }
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index 2ab2a2dff8cc4b5cf5deba9501c49dc6b2dc6483..0863f9e2ed1371c90e9aa3981c11e196bc98d16b 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -23,6 +23,10 @@ class FFN(Module):
         return self.ffn(x)
 
 
+#########################################################################################
+################################ Regex ########################################
+#########################################################################################
+
 regex_categories_axiom_links = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
 
 regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
@@ -32,10 +36,11 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
 #########################################################################################
 
 
-def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
+def get_axiom_links(max_atoms_in_one_type, sub_atoms_type_list, atoms_polarity, batch_axiom_links):
     r"""
     Args:
         max_atoms_in_one_type : configuration
+        sub_atoms_type_list : list of atom type to match
         atoms_polarity : (batch_size, max_atoms_in_sentence)
         batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
     Returns:
@@ -43,7 +48,7 @@
     """
     atoms_batch = get_atoms_links_batch(batch_axiom_links)
     linking_plus_to_minus_all_types = []
-    for atom_type in list(atom_map.keys())[:-1]:
+    for atom_type in sub_atoms_type_list:
         # filtrer sur atom_batch que ce type puis filtrer avec les indices sur atom polarity
         l_polarity_plus = [[x for i, x in enumerate(atoms_batch[s_idx]) if atoms_polarity[s_idx, i] and
                             bool(re.search(atom_type + "_", atoms_batch[s_idx][i]))] for s_idx in
@@ -74,6 +79,8 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
         return category_to_atoms_axiom_links(cat, categories_to_atoms)
+    elif category == "let":
+        return []
     elif True in res:
         return [category]
     else:
@@ -112,16 +119,15 @@ def category_to_atoms(category, categories_to_atoms):
     Returns:
         List of atoms inside the category in prefix order
     """
-    res = [bool(re.match(r'' + atom_type, category)) for atom_type in atom_map.keys()]
+    res = [(category == atom_type) for atom_type in atom_map.keys()]
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
         return category_to_atoms(cat, categories_to_atoms)
-    elif True in res:
-        return [category]
     elif category == "let":
         return []
+    elif True in res:
+        return [category]
     else:
-        print(category)
         category_cut = regex.match(regex_categories, category).groups()
         category_cut = [cat for cat in category_cut if cat is not None]
         for cat in category_cut:
@@ -158,7 +164,7 @@ def category_to_atoms_polarity(category, polarity):
         Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
     """
     category_to_polarity = []
-    res = [bool(re.match(r'' + atom_type, category)) for atom_type in atom_map.keys()]
+    res = [(category == atom_type) for atom_type in atom_map.keys()]
 
     # mot final
     if category.startswith("GOAL:"):
@@ -223,10 +229,10 @@ def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
 
 
 def get_pos_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
-                               atom_type, s_idx):
+                               atom_type, inverse_map, s_idx):
     pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
                     if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                        atoms_batch_tokenized[s_idx][i] == atom_map[atom_type] and
+                        bool(re.match(r'' + atom_type + '_?\w*', inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
                         atoms_polarity_batch[s_idx][i])]
     if len(pos_encoding) == 0:
         return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
@@ -235,10 +241,10 @@ def get_pos_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_
 
 
 def get_neg_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
-                               atom_type, s_idx):
+                               atom_type, inverse_map, s_idx):
     neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
                     if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                        atoms_batch_tokenized[s_idx][i] == atom_map[atom_type] and
+                        bool(re.match(r'' + atom_type + '_?\w*', inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
                         not atoms_polarity_batch[s_idx][i])]
    if len(neg_encoding) == 0:
         return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
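
Reviewer note, not part of the patch: the behavioural core of this change is that get_pos_encoding_for_s_idx and get_neg_encoding_for_s_idx no longer test token ids for exact equality with atom_map[atom_type]; they detokenize each id through inverse_map and match the atom name against the prefix pattern atom_type + '_?\w*', so sub-typed atoms (s_inf, s_ppart, pp_de, ...) are pooled under the seven base types in sub_atoms_type_list. A minimal standalone sketch of that rule, with matches_base_type as a hypothetical helper name:

    import re

    # Mirror of the new 17-entry atom_map introduced by this patch.
    atom_map = {
        'cl_r': 0, 'pp': 1, 'n': 2, 's_ppres': 3, 's_whq': 4, 's_q': 5, 'np': 6,
        's_inf': 7, 's_pass': 8, 'pp_a': 9, 'pp_par': 10, 'pp_de': 11, 'cl_y': 12,
        'txt': 13, 's': 14, 's_ppart': 15, '[PAD]': 16,
    }
    inverse_map = {v: k for k, v in atom_map.items()}

    # Base types the Linker now iterates over (self.sub_atoms_type_list).
    sub_atoms_type_list = ['cl_r', 'pp', 'n', 'np', 'cl_y', 'txt', 's']

    def matches_base_type(token_id, base_type):
        # Same regex as the patched utils_linker helpers: the detokenized atom
        # name must start with the base type, optionally followed by _<suffix>.
        return bool(re.match(base_type + r'_?\w*', inverse_map[int(token_id)]))

    assert matches_base_type(atom_map['s_inf'], 's')      # sub-typed atom pooled under 's'
    assert matches_base_type(atom_map['s'], 's')          # plain 's' still matches
    assert not matches_base_type(atom_map['[PAD]'], 's')  # padding never matches

One point worth double-checking: with this prefix pattern the name 'np' also matches the base type 'n' (re.match(r'n_?\w*', 'np') succeeds), so the 'n' and 'np' buckets overlap unless that is the intended behaviour.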