From 1f43915a006c860577b7710e01d4706c3fc0cf1d Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Tue, 17 May 2022 17:14:18 +0200
Subject: [PATCH] update linker encoding

---
 Configuration/config.ini |  7 +--
 Linker/Linker.py         | 92 ++++++++++++++++++++++++++++------------
 Linker/__init__.py       |  1 +
 Linker/utils_linker.py   | 72 ++++++++++++++++---------------
 main.py                  |  4 +-
 train.py                 |  9 +---
 6 files changed, 108 insertions(+), 77 deletions(-)

diff --git a/Configuration/config.ini b/Configuration/config.ini
index 15547f6..c79def5 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -31,9 +31,4 @@ device=cpu
 batch_size=32
 epoch=20
 seed_val=42
-learning_rate=0.005
-use_checkpoint_SAVE=0
-output_path=Output
-use_checkpoint_LOAD=0
-input_path=Input
-model_to_load=model_check.pt
\ No newline at end of file
+learning_rate=0.005
\ No newline at end of file
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 7f7462d..d2a4d89 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -37,7 +37,7 @@ class Linker(Module):
         self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
         self.dropout = Dropout(0.1)
-        self.device = ""
+        self.device = "cpu"
 
         self.Supertagger = supertagger
 
@@ -66,6 +66,16 @@ class Linker(Module):
             num_training_steps=100)
 
     def __preprocess_data(self, batch_size, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.0):
+        r"""
+        Args:
+            batch_size : int
+            df_axiom_links : pandas DataFrame
+            sentences_tokens
+            sentences_mask
+            validation_rate
+        Returns:
+            the training dataloader and the validation dataloader. They contain the list of atoms, their polarities, the axiom links, the tokenized sentences and the sentence masks
+        """
         atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
         atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
@@ -98,22 +108,21 @@ class Linker(Module):
         return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
 
     def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
-        r'''
-        Parameters :
-        atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
-        atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
-        sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
-        sents_mask
-        Returns :
-        link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat)
-        '''
+        r"""
+        Args:
+            atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
+            atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
+            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+            sents_mask : mask from BERT tokenizer
+        Returns:
+            link_weights : (atom_vocab_size, batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat)
+        """
 
         # atoms embedding
         atoms_embedding = self.atoms_embedding(atoms_batch_tokenized)
 
         # MHA ou LSTM avec sortie de BERT
-        batch_size, _, _ = sents_embedding.shape
-        sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence)
+        sents_mask = sents_mask.unsqueeze(1).repeat(self.nhead, self.max_atoms_in_sentence, 1).to(torch.float64)
         atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
                                             self.make_decoder_mask(atoms_batch_tokenized))
 
@@ -147,15 +156,35 @@ class Linker(Module):
 
     def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20,
                      batch_size=32, checkpoint=True, validate=True):
-
+        r"""
+        Args:
+            df_axiom_links : pandas DataFrame containing the atoms annotated with _i
+            sentences_tokens : sentences tokenized by BERT
+            sentences_mask : mask of tokens
+            validation_rate : float
+            epochs : int
+            batch_size : int
+            checkpoint : boolean
+            validate : boolean
+        Returns:
+            Final accuracy and final loss
+        """
         training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                             sentences_tokens, sentences_mask,
                                                                             validation_rate)
-
        for epoch_i in range(0, epochs):
             epoch_acc, epoch_loss = self.train_epoch(training_dataloader, validation_dataloader, checkpoint, validate)
 
     def train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True):
+        r"""Train one epoch
+
+        Args:
+            training_dataloader : DataLoader from torch, contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
+            validation_dataloader : DataLoader from torch, contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
+        Returns:
+            accuracy on validation set
+            loss on train set
+        """
 
         # Reset the total loss for this epoch.
         epoch_loss = 0
@@ -195,8 +224,8 @@ class Linker(Module):
         print("Average Loss on train dataset : ", avg_train_loss)
 
         if checkpoint:
-            checkpoint_dir = os.path.join("Output", 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M'))
-            self.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))
+            self.__checkpoint_save(
+                path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
 
         if validate:
             with torch.no_grad():
@@ -204,17 +233,20 @@ class Linker(Module):
                 print("Average Loss on test dataset : ", average_test_loss)
                 print("Average Accuracy on test dataset : ", accuracy)
+                print('\n')
+
+        return accuracy, avg_train_loss
 
     def predict(self, categories, sents_embedding, sents_mask=None):
-        r'''
-        Parameters :
-        categories : (batch_size, len_sentence)
-        sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
-        sents_mask
-        Returns :
-        axiom_links : atom_vocab_size, batch-size, max_atoms_in_one_cat)
-        '''
+        r"""Prediction of axiom links from the categories and the hidden states output by BERT
+
+        Args:
+            categories : (batch_size, len_sentence)
+            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+            sents_mask : mask from BERT tokenizer
+        Returns:
+            axiom_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat)
+        """
         self.eval()
 
         # get atoms
@@ -268,8 +300,9 @@ class Linker(Module):
             batch_sentences_tokens = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
             batch_sentences_mask = batch[4].to("cuda" if torch.cuda.is_available() else "cpu")
 
-            logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, batch_sentences_tokens,
-                                                   batch_sentences_mask)
+            logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask)
+            logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding,
+                                           batch_sentences_mask)
             logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
             axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
@@ -315,7 +348,10 @@ class Linker(Module):
         print("#" * 15)
 
     def __checkpoint_save(self, path='/linker.pt'):
-        self.linker.cpu()
+        """
+        @param path: file path of the checkpoint
+        """
+        self.cpu()
         torch.save({
             'args': dict(atom_map=self.atom_map,
                                          max_atoms_in_sentence=self.max_atoms_in_sentence),
@@ -325,4 +361,4 @@ class Linker(Module):
             'neg_transformation': self.neg_transformation.state_dict(),
             'optimizer': self.optimizer,
         }, path)
-        self.linker.to(self.device)
+        #self.to(self.device)
diff --git a/Linker/__init__.py b/Linker/__init__.py
index e69de29..c0df5b8 100644
--- a/Linker/__init__.py
+++ b/Linker/__init__.py
@@ -0,0 +1 @@
+from .Linker import Linker
\ No newline at end of file
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index da295de..13c63f4 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -32,14 +32,14 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
 
 
 def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
-    r'''
-    Parameters :
-    max_atoms_in_one_type : configuration
-    atoms_polarity : (batch_size, max_atoms_in_sentence)
-    batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
-    Returns :
-    batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
-    '''
+    r"""
+    Args:
+        max_atoms_in_one_type : configuration
+        atoms_polarity : (batch_size, max_atoms_in_sentence)
+        batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
+    Returns:
+        batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
+    """
     atoms_batch = get_atoms_links_batch(batch_axiom_links)
     linking_plus_to_minus_all_types = []
     for atom_type in list(atom_map.keys())[:-1]:
@@ -62,13 +62,13 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
 
 
 def category_to_atoms_axiom_links(category, categories_to_atoms):
-    r'''
-    Parameters :
-    category
-    categories_to_atoms : recursive list
+    r"""
+    Args:
+        category : str of kind AtomCat | CategoryCat(dr or dl)
+        categories_to_atoms : recursive list
     Returns :
-    List of atoms inside the category in prefix order
-    '''
+        List of atoms inside the category in prefix order
+    """
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
@@ -85,7 +85,8 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
 
 def get_atoms_links_batch(category_batch):
     r"""
-    category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
+    Args:
+        category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
     Returns :
         (batch_size, max_atoms_in_sentence) flattened categories in prefix order
     """
@@ -104,13 +105,13 @@
 
 
 def category_to_atoms(category, categories_to_atoms):
-    r'''
-    Parameters :
-    category
-    categories_to_atoms : recursive list
-    Returns :
-    List of atoms inside the category in prefix order
-    '''
+    r"""
+    Args:
+        category : str of kind AtomCat | CategoryCat(dr or dl)
+        categories_to_atoms : recursive list
+    Returns:
+        List of atoms inside the category in prefix order
+    """
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
@@ -129,8 +130,9 @@
 
 def get_atoms_batch(category_batch):
     r"""
-    category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
-    Returns :
+    Args:
+        category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
+    Returns:
         (batch_size, max_atoms_in_sentence) flattened categories in prefix order
     """
     batch = []
@@ -147,12 +149,13 @@
 #########################################################################################
 
 
 def category_to_atoms_polarity(category, polarity):
-    r'''
-    Parameters :
-    category : str of kind AtomCat | CategoryCat(dr or dl)
-    Returns :
-    Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
-    '''
+    r"""
+    Args:
+        category : str of kind AtomCat | CategoryCat(dr or dl)
+        polarity : current polarity carried through the recursion
+    Returns:
+        Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
+    """
     category_to_polarity = []
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if category.startswith("GOAL:"):
@@ -233,10 +236,11 @@ def category_to_atoms_polarity(category, polarity):
 
 def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
     r"""
-    max_atoms_in_sentence : configuration
-    atoms_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
-    Returns :
-    (batch_size, max_atoms_in_sentence) flattened categories'polarities in prefix order
+    Args:
+        max_atoms_in_sentence : configuration
+        atoms_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
+    Returns:
+        (batch_size, max_atoms_in_sentence) flattened categories' polarities in prefix order
     """
     list_batch = []
     for sentence in atoms_batch:
diff --git a/main.py b/main.py
index 55e8c52..14d3fc0 100644
--- a/main.py
+++ b/main.py
@@ -1,8 +1,8 @@
 import torch.nn.functional as F
 import torch
 from Configuration import Configuration
-from Linker.Linker import Linker
-from Supertagger.SuperTagger.SuperTagger import SuperTagger
+from Linker import *
+from Supertagger import *
 
 max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
diff --git a/train.py b/train.py
index f83a951..bc2f785 100644
--- a/train.py
+++ b/train.py
@@ -1,12 +1,10 @@
 import torch
-
 from Configuration import Configuration
-from Linker.Linker import Linker
-from Supertagger.SuperTagger.SuperTagger import SuperTagger
+from Linker import *
+from Supertagger import *
 from utils import read_csv_pgbar
 
 torch.cuda.empty_cache()
-
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
 nb_sentences = batch_size * 10
 epochs = int(Configuration.modelTrainingConfig['epoch'])
@@ -15,14 +13,11 @@ file_path_axiom_links = 'Datasets/aa1_links_dataset_links.csv'
 df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 
 sentences_batch = df_axiom_links["Sentences"].tolist()
-
 supertagger = SuperTagger()
 supertagger.load_weights("models/model_supertagger.pt")
-
 sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
 print("Linker")
 linker = Linker(supertagger)
-
 print("Linker Training")
 linker.train_linker(df_axiom_links, sents_tokenized, sents_mask, validation_rate=0.1, epochs=epochs,
                     batch_size=batch_size, checkpoint=True, validate=True)
-- 
GitLab
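
Note (not part of the patch): the core change in Linker.forward() replaces the random attention mask with a broadcast of the BERT padding mask. Below is a minimal, self-contained sketch of that shape arithmetic, using hypothetical sizes chosen only for illustration (they are not taken from config.ini); the variable names mirror the patch.

import torch

# Hypothetical sizes, for illustration only.
batch_size, nhead = 2, 8
max_atoms_in_sentence, len_sentence = 6, 10

# BERT-style padding mask: (batch_size, len_sentence), 1 for real tokens, 0 for padding.
sents_mask = torch.ones(batch_size, len_sentence)

# Same transformation as the updated forward(): broadcast the sentence mask
# over every atom position, once per attention head.
attn_mask = sents_mask.unsqueeze(1).repeat(nhead, max_atoms_in_sentence, 1).to(torch.float64)

print(attn_mask.shape)  # torch.Size([16, 6, 10]) -> (batch_size * nhead, max_atoms_in_sentence, len_sentence)

This yields one (max_atoms_in_sentence, len_sentence) mask per head and per sentence, the layout expected by a multi-head cross-attention over the BERT output, instead of the previous torch.randn placeholder.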