diff --git a/Configuration/config.ini b/Configuration/config.ini
index c79def55882180c36facd03f2d0d9593501a13c6..d6c860553e48b27f2c16d5d1c2228c86e51b7fc4 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -12,15 +12,15 @@ max_atoms_in_one_type=250
 dim_encoder = 768
 
 [MODEL_DECODER]
-dim_decoder = 8
+dim_decoder = 16
 num_rnn_layers=1
 dropout=0.1
 teacher_forcing=0.05
 
 [MODEL_LINKER]
-nhead=1
+nhead=4
 dim_feedforward=246
-dim_embedding_atoms=8
+dim_embedding_atoms=16
 dim_polarity_transfo=128
 layer_norm_eps=1e-5
 dropout=0.1
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 12f9534edfc8fe5a35fbad8b62dbc2fe7ee44d67..ce256406be8271ec9a951af3f08f7cd7995c7a5a 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -17,7 +17,8 @@ from Linker.AtomTokenizer import AtomTokenizer
 from Linker.MHA import AttentionDecoderLayer
 from Linker.atom_map import atom_map
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links
+from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
+    get_neg_encoding_for_s_idx
 from Linker.eval import mesure_accuracy, SinkhornLoss
 from utils import pad_sequence
 
@@ -130,23 +131,17 @@ class Linker(Module):
         link_weights = []
         for atom_type in list(self.atom_map.keys())[:-1]:
-            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
-                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[
-                                                              atom_type] and
-                                                          atoms_polarity_batch[s_idx][i])] + [
-                                                         torch.zeros(self.dim_embedding_atoms, device=self.device)]).to(self.device)
-                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                                        max_len=self.max_atoms_in_one_type // 2).to(self.device)
-
-            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
-                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[
-                                                              atom_type] and
-                                                          not atoms_polarity_batch[s_idx][i])] + [
-                                                         torch.zeros(self.dim_embedding_atoms, device=self.device)]).to(self.device)
-                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                                        max_len=self.max_atoms_in_one_type // 2).to(self.device)
+            pos_encoding = pad_sequence(
+                [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
+                                            atoms_polarity_batch, atom_type, s_idx)
+                 for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
+                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+
+            neg_encoding = pad_sequence(
+                [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
+                                            atoms_polarity_batch, atom_type, s_idx)
+                 for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
+                max_len=self.max_atoms_in_one_type // 2).to(self.device)
 
             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
@@ -271,23 +266,17 @@ class Linker(Module):
         link_weights = []
         for atom_type in list(self.atom_map.keys())[:-1]:
-            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
-                                                      if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
-                                                          atoms_tokenized[s_idx][i] == self.atom_map[
-                                                              atom_type] and
-                                                          polarities[s_idx][i])] + [
-                                                         torch.zeros(self.dim_embedding_atoms, device=self.device)])
-                                         for s_idx in range(len(polarities))], padding_value=0,
-                                        max_len=self.max_atoms_in_one_type // 2).to(self.device)
-
-            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
-                                                      if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
-                                                          atoms_tokenized[s_idx][i] == self.atom_map[
-                                                              atom_type] and
-                                                          not polarities[s_idx][i])] + [
-                                                         torch.zeros(self.dim_embedding_atoms, device=self.device)])
-                                         for s_idx in range(len(polarities))], padding_value=0,
-                                        max_len=self.max_atoms_in_one_type // 2).to(self.device)
+            pos_encoding = pad_sequence(
+                [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
+                                            polarities, atom_type, s_idx)
+                 for s_idx in range(len(polarities))], padding_value=0,
+                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+
+            neg_encoding = pad_sequence(
+                [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
+                                            polarities, atom_type, s_idx)
+                 for s_idx in range(len(polarities))], padding_value=0,
+                max_len=self.max_atoms_in_one_type // 2).to(self.device)
 
             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
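Both hunks above replace the same inlined filtering logic, previously duplicated between the training forward pass and the prediction path, with calls to the new helpers defined in Linker/utils_linker.py further down. A minimal sketch of what one helper call computes, using toy atom ids and a small embedding dimension (hypothetical values, not the repository's real atom_map):

    import torch

    atom_map = {"np": 1, "s": 2}                            # assumed toy mapping
    dim_embedding_atoms = 4                                 # config.ini now uses 16; 4 keeps the demo small
    atoms_encoding = [torch.randn(3, dim_embedding_atoms)]  # one sentence, three atoms
    atoms_batch_tokenized = [[1, 2, 1]]                     # atom id per position
    atoms_polarity_batch = [[True, False, True]]            # polarity flag per position

    # get_pos_encoding_for_s_idx keeps the encodings of positive "np" atoms,
    # here the vectors at positions 0 and 2 of sentence 0:
    pos = [x for i, x in enumerate(atoms_encoding[0])
           if atoms_batch_tokenized[0][i] == atom_map["np"] and atoms_polarity_batch[0][i]]
    print(torch.stack(pos).shape)  # torch.Size([2, 4])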
diff --git a/Linker/__init__.py b/Linker/__init__.py
index c0df5b8d2f6b10dc52709b2bd7b132eb1c1c2066..c2a9483d03b868e0c2b00cae7a54f7bb7b7bd4db 100644
--- a/Linker/__init__.py
+++ b/Linker/__init__.py
@@ -1 +1,3 @@
-from .Linker import Linker
\ No newline at end of file
+from .Linker import Linker
+from .atom_map import atom_map
+from .AtomEmbedding import AtomEmbedding
\ No newline at end of file
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index 13c63f47346f9aac35ac230734994cfb227d036b..0821f6196c55a3c0961ecc89c23aaacde8f53140 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -27,7 +27,7 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
 
 
 #########################################################################################
-################################ Liste des atoms avc _i########################################
+################################ Liste des atoms avec _i ######################################
 #########################################################################################
 
 
@@ -72,7 +72,7 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
-        return [cat]
+        return category_to_atoms_axiom_links(cat, categories_to_atoms)
     elif True in res:
         return [category]
     else:
@@ -103,7 +103,6 @@ def get_atoms_links_batch(category_batch):
 ################################ Liste des atoms ########################################
 #########################################################################################
 
-
 def category_to_atoms(category, categories_to_atoms):
     r"""
     Args:
@@ -115,8 +114,7 @@ def category_to_atoms(category, categories_to_atoms):
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
-        category = re.match(r'([a-zA-Z|_]+)_\d+', cat).group(1)
-        return [category]
+        return category_to_atoms(cat, categories_to_atoms)
    elif True in res:
         category = re.match(r'([a-zA-Z|_]+)_\d+', category).group(1)
         return [category]
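The two GOAL fixes above replace special-cased handling with recursion on the category after the colon. This matters because the old code in category_to_atoms applied the regex r'([a-zA-Z|_]+)_\d+' directly to the goal, which only succeeds for an atomic goal such as "np_1" and raises AttributeError on a compound one; recursing lets the existing atomic and compound branches handle both. A quick check of that regex behaviour (hypothetical category strings):

    import re

    assert re.match(r'([a-zA-Z|_]+)_\d+', "np_1").group(1) == "np"
    assert re.match(r'([a-zA-Z|_]+)_\d+', "dr(0,s_1,np_2)") is None  # .group(1) would crash here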
@@ -158,78 +156,41 @@ def category_to_atoms_polarity(category, polarity):
     """
     category_to_polarity = []
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
+
+    # mot final
     if category.startswith("GOAL:"):
-        category_to_polarity.append(True)
+        word, cat = category.split(':')
+        res = [bool(re.match(r'' + atom_type + "_\d+", cat)) for atom_type in atom_map.keys()]
+        if True in res:
+            category_to_polarity.append(True)
+        else:
+            category_to_polarity += category_to_atoms_polarity(cat, True)
+
+    # le mot a une category atomique
     elif True in res or category.startswith("dia") or category.startswith("box"):
-        category_to_polarity.append(False)
+        category_to_polarity.append(not polarity)
+
+    # sinon c'est une formule longue
     else:
         # dr = /
         if category.startswith("dr"):
             category_cut = regex.match(regex_categories, category).groups()
             category_cut = [cat for cat in category_cut if cat is not None]
             left_side, right_side = category_cut[0], category_cut[1]
-
-            if polarity == True:
-                # for the left side : normal
-                res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
-                if True in res or left_side.startswith("dia") or left_side.startswith("box"):
-                    category_to_polarity.append(False)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(left_side, True)
-                # for the right side : change polarity for next right formula
-                res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
-                if True in res or right_side.startswith("dia") or right_side.startswith("box"):
-                    category_to_polarity.append(True)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(right_side, False)
-
-            else:
-                # for the left side
-                res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
-                if True in res or left_side.startswith("dia") or left_side.startswith("box"):
-                    category_to_polarity.append(True)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(left_side, False)
-                # for the right side : change polarity for next right formula
-                res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
-                if True in res or right_side.startswith("dia") or right_side.startswith("box"):
-                    category_to_polarity.append(False)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(right_side, True)
+            # for the left side
+            category_to_polarity += category_to_atoms_polarity(left_side, polarity)
+            # for the right side : change polarity for next right formula
+            category_to_polarity += category_to_atoms_polarity(right_side, not polarity)
 
         # dl = \
         elif category.startswith("dl"):
             category_cut = regex.match(regex_categories, category).groups()
             category_cut = [cat for cat in category_cut if cat is not None]
             left_side, right_side = category_cut[0], category_cut[1]
-
-            if polarity == True:
-                # for the left side : change polarity
-                res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
-                if True in res or left_side.startswith("dia") or left_side.startswith("box"):
-                    category_to_polarity.append(True)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(left_side, False)
-                # for the right side : normal
-                res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
-                if True in res or right_side.startswith("dia") or right_side.startswith("box"):
-                    category_to_polarity.append(False)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(right_side, True)
-
-            else:
-                # for the left side
-                res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
-                if True in res or left_side.startswith("dia") or left_side.startswith("box"):
-                    category_to_polarity.append(False)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(left_side, True)
-                # for the right side
-                res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
-                if True in res or right_side.startswith("dia") or right_side.startswith("box"):
-                    category_to_polarity.append(True)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(right_side, False)
+            # for the left side
+            category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
+            # for the right side
+            category_to_polarity += category_to_atoms_polarity(right_side, polarity)
 
     return category_to_polarity
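The hunk above collapses four nearly identical branches into two recursive calls by threading the polarity argument through: an atomic subformula contributes not polarity, dr keeps the left-hand polarity and flips the right, and dl does the reverse. A worked trace (assuming "s" and "np" are atom_map keys):

    # category_to_atoms_polarity("dr(0,s_1,np_2)", True)
    #   left  "s_1" : recurse with polarity=True  -> atomic, appends not True  == False
    #   right "np_2": recurse with polarity=False -> atomic, appends not False == True
    # result: [False, True], the same list the old polarity == True / "dr" branch built by hand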
@@ -251,3 +212,32 @@ def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
         list_batch.append(torch.as_tensor(list_atoms))
     return pad_sequence([list_batch[i] for i in range(len(list_batch))], max_len=max_atoms_in_sentence,
                         padding_value=0)
+
+
+#########################################################################################
+################################ Prepare encoding ###############################################
+#########################################################################################
+
+
+def get_pos_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
+                               atom_type, s_idx):
+    pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
+                    if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                        atoms_batch_tokenized[s_idx][i] == atom_map[atom_type] and
+                        atoms_polarity_batch[s_idx][i])]
+    if len(pos_encoding) == 0:
+        return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+    else:
+        return torch.stack(pos_encoding)
+
+
+def get_neg_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
+                               atom_type, s_idx):
+    neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
+                    if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                        atoms_batch_tokenized[s_idx][i] == atom_map[atom_type] and
+                        not atoms_polarity_batch[s_idx][i])]
+    if len(neg_encoding) == 0:
+        return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+    else:
+        return torch.stack(neg_encoding)
diff --git a/train.py b/train.py
index b2e73259a8ad242fb2103e22dbedf531d9e9c1d2..2c502001a16b1dbb13f19082f4c39ecc4dfbb4c5 100644
--- a/train.py
+++ b/train.py
@@ -6,7 +6,7 @@ from utils import read_csv_pgbar
 torch.cuda.empty_cache()
 
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 10
+nb_sentences = batch_size * 200
 epochs = int(Configuration.modelTrainingConfig['epoch'])
 
 file_path_axiom_links = 'Datasets/aa1_links_dataset_links.csv'
@@ -15,8 +15,6 @@ df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 sentences_batch = df_axiom_links["Sentences"].tolist()
 
 supertagger = SuperTagger()
 supertagger.load_weights("models/model_supertagger.pt")
-
-
 sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 print("Linker")
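A behavioural detail of the new helpers worth noting: when a sentence contains no atom of the requested type and polarity, they return a single zero row of shape (1, dim_embedding_atoms) so that torch.stack is never called on an empty list; pad_sequence then pads each per-sentence tensor up to max_atoms_in_one_type // 2. A toy illustration of that fallback (hypothetical data, small dimension):

    import torch

    dim_embedding_atoms = 4
    matches = []  # no atom of this type/polarity in the sentence
    enc = torch.zeros(1, dim_embedding_atoms) if len(matches) == 0 else torch.stack(matches)
    print(enc.shape)  # torch.Size([1, 4]); padding later extends this to max_atoms_in_one_type // 2 rows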