diff --git a/Linker/Linker.py b/Linker/Linker.py
index a3306db59234b41f443f5e38a1dd7a584504b355..6d14226412d5ca9e58b38ac19f11b7dd8d865597 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -11,7 +11,6 @@ from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from transformers import get_cosine_schedule_with_warmup
 
 from Configuration import Configuration
 from Linker.AtomEmbedding import AtomEmbedding
@@ -92,6 +91,9 @@ class Linker(Module):
 
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+        self.to(self.device)
+
+
     def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
         r"""
         Args:
@@ -161,13 +163,13 @@ class Linker(Module):
                 [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
                                             atoms_polarity_batch, atom_type, s_idx)
                  for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+                max_len=self.max_atoms_in_one_type // 2)
 
             neg_encoding = pad_sequence(
                 [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
                                             atoms_polarity_batch, atom_type, s_idx)
                  for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+                max_len=self.max_atoms_in_one_type // 2)
 
             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
@@ -195,8 +197,6 @@ class Linker(Module):
         """
         training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                             validation_rate)
 
-        self.to(self.device)
-
         if checkpoint or tensorboard:
             checkpoint_dir, writer = output_create_dir()
@@ -210,8 +210,7 @@ class Linker(Module):
             print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
             print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
             if validation_rate > 0.0:
-                with torch.no_grad():
-                    loss_test, accuracy_test = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
+                loss_test, accuracy_test = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
                 print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
 
             if checkpoint:
@@ -241,14 +240,13 @@ class Linker(Module):
             accuracy on validation set
             loss on train set
         """
+        self.train()
 
         # Reset the total loss for this epoch.
         epoch_loss = 0
         accuracy_train = 0
         t0 = time.time()
 
-        self.train()
-
         # For each batch of training data...
         with tqdm(training_dataloader, unit="batch") as tepoch:
             for batch in tepoch:
@@ -299,44 +297,44 @@ class Linker(Module):
             axiom_links : atom_vocab_size, batch-size, max_atoms_in_one_cat)
         """
         self.eval()
-
+        with torch.no_grad():
             # get atoms
-        atoms_batch = get_atoms_batch(categories)
-        atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
+            atoms_batch = get_atoms_batch(categories)
+            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
-        # get polarities
-        polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)
+            # get polarities
+            polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)
 
-        # atoms embedding
-        atoms_embedding = self.atoms_embedding(atoms_tokenized)
+            # atoms embedding
+            atoms_embedding = self.atoms_embedding(atoms_tokenized)
 
-        # MHA ou LSTM avec sortie de BERT
-        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
-                                             self.make_decoder_mask(atoms_tokenized))
+            # MHA ou LSTM avec sortie de BERT
+            atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
+                                                 self.make_decoder_mask(atoms_tokenized))
 
-        link_weights = []
-        for atom_type in list(self.atom_map.keys())[:-1]:
-            pos_encoding = pad_sequence(
-                [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                                            polarities, atom_type, s_idx)
-                 for s_idx in range(len(polarities))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+            link_weights = []
+            for atom_type in list(self.atom_map.keys())[:-1]:
+                pos_encoding = pad_sequence(
+                    [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
+                                                polarities, atom_type, s_idx)
+                     for s_idx in range(len(polarities))], padding_value=0,
+                    max_len=self.max_atoms_in_one_type // 2)
 
-            neg_encoding = pad_sequence(
-                [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                                            polarities, atom_type, s_idx)
-                 for s_idx in range(len(polarities))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+                neg_encoding = pad_sequence(
+                    [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
+                                                polarities, atom_type, s_idx)
+                     for s_idx in range(len(polarities))], padding_value=0,
+                    max_len=self.max_atoms_in_one_type // 2)
 
-            pos_encoding = self.pos_transformation(pos_encoding)
-            neg_encoding = self.neg_transformation(neg_encoding)
+                pos_encoding = self.pos_transformation(pos_encoding)
+                neg_encoding = self.neg_transformation(neg_encoding)
 
-            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-            link_weights.append(sinkhorn(weights, iters=3))
+                weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+                link_weights.append(sinkhorn(weights, iters=3))
 
-        logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
-        axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
-        return axiom_links
+            logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
+            axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
+            return axiom_links
 
     def eval_batch(self, batch, cross_entropy_loss):
         batch_atoms = batch[0].to(self.device)
@@ -361,12 +359,14 @@ class Linker(Module):
         Args:
            dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols
        """
+        self.eval()
 
         accuracy_average = 0
         loss_average = 0
-        for step, batch in enumerate(dataloader):
-            loss, accuracy = self.eval_batch(batch, cross_entropy_loss)
-            accuracy_average += accuracy
-            loss_average += float(loss)
+        with torch.no_grad():
+            for step, batch in enumerate(dataloader):
+                loss, accuracy = self.eval_batch(batch, cross_entropy_loss)
+                accuracy_average += accuracy
+                loss_average += float(loss)
 
         return loss_average / len(dataloader), accuracy_average / len(dataloader)
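Note on the pattern above (illustrative, not part of the patch): the changes converge on two standard PyTorch idioms, moving the module to its device once in __init__ so intermediate encodings no longer need per-tensor .to(self.device) calls, and gating evaluation behind self.eval() plus a single torch.no_grad() block rather than wrapping individual calls. A minimal sketch follows, assuming a generic nn.Module; ToyLinker, proj and the eval_epoch signature here are hypothetical stand-ins, not identifiers from this repository.

import torch
from torch import nn


class ToyLinker(nn.Module):
    """Hypothetical module (not from this repo) showing the device/eval pattern."""

    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Move the parameters once, here, instead of calling .to(self.device)
        # on every intermediate tensor or at the start of the training entry point.
        self.to(self.device)

    def forward(self, x):
        # Outputs computed from parameters already on self.device stay on that
        # device, so downstream encodings need no per-tensor .to(self.device).
        return self.proj(x)

    def eval_epoch(self, dataloader, loss_fn):
        # eval() switches off dropout/batch-norm updates; one no_grad() block
        # skips autograd bookkeeping for the whole validation loop.
        self.eval()
        total_loss = 0.0
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                total_loss += float(loss_fn(self(inputs), targets))
        return total_loss / len(dataloader)

Placing torch.no_grad() around the whole loop inside the evaluation method covers every batch uniformly, which is why the caller-side wrapper around the validation call in train_linker could be dropped.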