diff --git a/Linker/Linker.py b/Linker/Linker.py
index f65325eef04bc224f025f209f99f9d1a6f653207..ca3a8dcf4b6f00d62a315a2645ee4c59296828f8 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -1,16 +1,24 @@
+import os
+from datetime import datetime
+
 import torch
 from torch.nn import Sequential, LayerNorm, Dropout
 from torch.nn import Module
 import torch.nn.functional as F
 import sys
+
+from torch.optim import AdamW
+from torch.utils.data import TensorDataset, random_split
+from transformers import get_cosine_schedule_with_warmup
+
 from Configuration import Configuration
 from AtomEmbedding import AtomEmbedding
 from AtomTokenizer import AtomTokenizer
 from MHA import AttentionDecoderLayer
 from atom_map import atom_map
 from Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN
-from eval import mesure_accuracy
+from utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links
+from eval import mesure_accuracy, SinkhornLoss
 from ..utils import pad_sequence
@@ -27,6 +35,10 @@ class Linker(Module):
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
         self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
         self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
+        batch_size = int(Configuration.modelTrainingConfig['batch_size'])
+        nb_sentences = batch_size * 10
+        self.epochs = int(Configuration.modelTrainingConfig['epoch'])
+        learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
         self.dropout = Dropout(0.1)
         self.device = ""
@@ -47,6 +59,41 @@ class Linker(Module):
             LayerNorm(self.dim_embedding_atoms, eps=1e-12)
         )
 
+        self.cross_entropy_loss = SinkhornLoss()
+        self.optimizer = AdamW(self.parameters(),
+                               weight_decay=1e-5,
+                               lr=learning_rate)
+        self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
+                                                         num_warmup_steps=0,
+                                                         num_training_steps=100)
+
+    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.0):
+        atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
+        atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
+        atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
+
+        atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["sub_tree"])
+
+        torch.set_printoptions(edgeitems=20)
+        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
+                                            df_axiom_links["sub_tree"])
+        truth_links_batch = truth_links_batch.permute(1, 0, 2)
+
+        # Construct the tensor dataset
+        dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch)
+
+        if validation_rate > 0:
+            train_size = int(0.9 * len(dataset))
+            val_size = len(dataset) - train_size
+            train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+            validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+        else:
+            validation_dataloader = None
+            train_dataset = dataset
+
+        training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
+        return training_dataloader, validation_dataloader
+
     def make_decoder_mask(self, atoms_token):
         decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
         decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
@@ -101,6 +148,63 @@ class Linker(Module):
 
         return torch.stack(link_weights)
 
+    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32):
+
+        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links, validation_rate)
+        epochs = epochs - self.epochs
+        self.train()
+
+        for epoch_i in range(0, epochs):
+            epoch_acc, epoch_loss = self.__train_epoch(training_dataloader, validation_dataloader)
+
+    def __train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True):
+
+        # Reset the total loss for this epoch.
+        epoch_loss = 0
+
+        self.train()
+
+        # For each batch of training data...
+        for step, batch in enumerate(training_dataloader):
+            # Unpack this training batch from our dataloader
+            batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
+            batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
+            batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
+            # batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
+
+            self.optimizer.zero_grad()
+
+            # get sentence embedding from BERT which is already trained
+            # sentences_embedding = supertagger(batch_sentences)
+
+            # Run the linker on the category predictions
+            logits_predictions = self(batch_atoms, batch_polarity, [])
+
+            linker_loss = self.cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links)
+            # Perform a backward pass to calculate the gradients.
+            epoch_loss += float(linker_loss)
+            linker_loss.backward()
+
+            # This is to help prevent the "exploding gradients" problem.
+            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)
+
+            # Update parameters and take a step using the computed gradient.
+            self.optimizer.step()
+            self.scheduler.step()
+
+        avg_train_loss = epoch_loss / len(training_dataloader)
+
+        if checkpoint:
+            checkpoint_dir = os.path.join("Output", 'Training_' + datetime.today().strftime('%d-%m_%H-%M'))
+            self.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))
+
+        if validate:
+            self.eval()
+            with torch.no_grad():
+                accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
+
+        return accuracy, avg_train_loss
+
     def predict(self, categories, sents_embedding, sents_mask=None):
         r'''
         Parameters :
diff --git a/Utils/PostpreprocesTXT.py b/Utils/PostpreprocesTXT.py
index a1848e235f0d308b462c16ecdf2d90248e50a1f1..eaa9d30efb4d8bf0ae844cbc7b174c237ffc5f0c 100644
--- a/Utils/PostpreprocesTXT.py
+++ b/Utils/PostpreprocesTXT.py
@@ -25,7 +25,7 @@ def sub_tree_line(line_with_data: str):
     for word_with_data in line_list:
         w, t = sub_tree_word(word_with_data)
         sentence += ' ' + w
-        if t not in ["\\", "/", "let"] and len(t)>0:
+        if t not in ["\\", "/", "let"] and len(t) > 0:
             sub_trees.append([t])
     """if ('ppp' in list(itertools.chain(*sub_trees))):
         print(sentence)"""
@@ -35,17 +35,9 @@ def Txt_to_csv(file_name: str):
     file = open(file_name, "r", encoding="utf8")
     text = file.readlines()
-    sub = [sub_tree_line(data) for data in text]
-    df = pd.DataFrame(data=sub, columns=['Sentences', 'sub_tree'])
-    df.to_csv("../Datasets/" + file_name[:-4] + "_dataset_links.csv", index=False)
 
 
 Txt_to_csv("aa1_links.txt")
-
-"""trees = df['sub_tree']
-trees_flat = set(list(itertools.chain(*list(itertools.chain(*trees)))))
-fruit_dictionary = dict(zip(list(trees_flat), range(len(list(trees_flat)))))
-print(fruit_dictionary)"""
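
Usage note (not part of the patch): a minimal sketch of how the new train_linker entry point would be driven, under stated assumptions. The CSV path, the import line and the bare Linker() construction are illustrative placeholders; only the train_linker(df_axiom_links, validation_rate, epochs, batch_size) signature and the expectation of a "sub_tree" column come from the code above.

import pandas as pd

from Linker.Linker import Linker  # import path assumed; adjust to the actual package layout

# Hypothetical dataset: any CSV with a "sub_tree" column, e.g. a
# "<name>_dataset_links.csv" file previously written by Utils/PostpreprocesTXT.py.
df_axiom_links = pd.read_csv("../Datasets/aa1_links_dataset_links.csv")

linker = Linker()  # constructor arguments omitted here; see Linker.__init__

# Preprocesses the axiom links, builds the train/validation DataLoaders
# and runs one training epoch per iteration.
linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32)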