diff --git a/Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651476773.montana.6860.1 b/Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651476773.montana.6860.1
deleted file mode 100644
index c5715fc44ef387ce7abb6b3341d37769ed78cd7c..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651476773.montana.6860.1 and /dev/null differ
diff --git a/Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651476773.montana.6860.2 b/Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651476773.montana.6860.2
deleted file mode 100644
index e5f81db34af3cd706eea16ef9d3bdc047ce0cae0..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651476773.montana.6860.2 and /dev/null differ
diff --git a/Output/Tranning_02-05_08-57/logs/events.out.tfevents.1651474623.montana.6860.0 b/Output/Tranning_02-05_08-57/logs/events.out.tfevents.1651474623.montana.6860.0
deleted file mode 100644
index efbfc7d8a55d2cb0f5ed2b1178f473f0e8e93be7..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_02-05_08-57/logs/events.out.tfevents.1651474623.montana.6860.0 and /dev/null differ
diff --git a/Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651488612.montana.35996.1 b/Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651488612.montana.35996.1
deleted file mode 100644
index 5cfc526c1a26ffcabcf14d00f42cc75bf018a1c4..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651488612.montana.35996.1 and /dev/null differ
diff --git a/Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651488612.montana.35996.2 b/Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651488612.montana.35996.2
deleted file mode 100644
index 138bd57e18ec66a2f43191cb7b7af73fe3d56277..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651488612.montana.35996.2 and /dev/null differ
diff --git a/Output/Tranning_02-05_11-36/logs/events.out.tfevents.1651484213.montana.35996.0 b/Output/Tranning_02-05_11-36/logs/events.out.tfevents.1651484213.montana.35996.0
deleted file mode 100644
index 9883309c9bb3dda787beef437a6a4df6ae931b3f..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_02-05_11-36/logs/events.out.tfevents.1651484213.montana.35996.0 and /dev/null differ
diff --git a/Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651239629.montana.55449.1 b/Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651239629.montana.55449.1
deleted file mode 100644
index ae9e1c819bafc59a81b9616301304b4eb3485cb8..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651239629.montana.55449.1 and /dev/null differ
diff --git a/Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651239629.montana.55449.2 b/Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651239629.montana.55449.2
deleted file mode 100644
index 95b5a29a10dc72589e9a0cd1bee8aac5c20574f2..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651239629.montana.55449.2 and /dev/null differ
diff --git a/SuperTagger/EncoderDecoder.py b/SuperTagger/EncoderDecoder.py
deleted file mode 100644
index 36311d552b3b54fb3c4cf0135ccb077989ef79a8..0000000000000000000000000000000000000000
--- a/SuperTagger/EncoderDecoder.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import torch
-from torch.nn import Dropout
-from torch.nn import Module
-
-from Configuration import Configuration
-from SuperTagger.Decoder.RNNDecoderLayer import RNNDecoderLayer
-from SuperTagger.Encoder.EncoderLayer import EncoderLayer
-from SuperTagger.eval import measure_supertagging_accuracy
-
-
-class EncoderDecoder(Module):
-    """
-    A standard Encoder-Decoder architecture. Base for this and many
-    other models.
-
-    decoder : instance of Decoder
-    """
-
-    def __init__(self, BASE_TOKENIZER, BASE_MODEL, symbols_map):
-        super(EncoderDecoder, self).__init__()
-
-        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
-        self.max_symbols_in_sentence = int(Configuration.datasetConfig['max_symbols_in_sentence'])
-        self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
-
-        self.symbols_map = symbols_map
-        self.sents_padding_id = BASE_TOKENIZER.pad_token_id
-        self.sents_space_id = BASE_TOKENIZER.bos_token_id
-
-        self.encoder = EncoderLayer(BASE_MODEL)
-        self.decoder = RNNDecoderLayer(self.symbols_map)
-
-        self.dropout = Dropout(0.1)
-
-    def forward(self, sents_tokenized_batch, sents_mask_batch, symbols_tokenized_batch):
-        r"""Training the translation from sentence to symbols
-
-        Args:
-            sents_tokenized_batch: [batch_size, max_len_sentence] the tokenized sentences
-            sents_mask_batch : mask output from the encoder tokenizer
-            symbols_tokenized_batch: [batch_size, max_symbols_in_sentence] the true symbols for each sentence.
-        """
-        last_hidden_state, pooler_output = self.encoder([sents_tokenized_batch, sents_mask_batch])
-        last_hidden_state = self.dropout(last_hidden_state)
-        return self.decoder(symbols_tokenized_batch, last_hidden_state, pooler_output)
-
-    def decode_greedy_rnn(self, sents_tokenized_batch, sents_mask_batch):
-        r"""Predicts the symbols for each sentence in sents_tokenized_batch.
-
-        Args:
-            sents_tokenized_batch: [batch_size, max_len_sentence] the tokenized sentences
-            sents_mask_batch : mask output from the encoder tokenizer
-        """
-        last_hidden_state, pooler_output = self.encoder([sents_tokenized_batch, sents_mask_batch])
-        last_hidden_state = self.dropout(last_hidden_state)
-
-        predictions = self.decoder.predict_rnn(last_hidden_state, pooler_output)
-
-        return predictions
-
-    def eval_batch(self, batch, cross_entropy_loss):
-        r"""Calls the evaluating methods after predicting the symbols from the sentence contained in batch
-
-        Args:
-            batch: contains the tokenized sentences, their masks and the true symbols.
- """ - b_sents_tokenized = batch[0].to("cuda" if torch.cuda.is_available() else "cpu") - b_sents_mask = batch[1].to("cuda" if torch.cuda.is_available() else "cpu") - b_symbols_tokenized = batch[2].to("cuda" if torch.cuda.is_available() else "cpu") - - type_predictions = self.decode_greedy_rnn(b_sents_tokenized, b_sents_mask) - - pred = torch.argmax(type_predictions, dim=2) - - predict_trad = [{v: k for k, v in self.symbols_map.items()}[int(i)] for i in pred[0]] - true_trad = [{v: k for k, v in self.symbols_map.items()}[int(i)] for i in b_symbols_tokenized[0]] - l = len([i for i in true_trad if i != '[PAD]']) - print("\nsub true (", l, ") : ", - [token for token in true_trad if token != '[PAD]']) - print("\nsub predict (", len([i for i in predict_trad if i != '[PAD]']), ") : ", - [token for token in predict_trad if token != '[PAD]']) - - return measure_supertagging_accuracy(pred, b_symbols_tokenized, - ignore_idx=self.symbols_map["[PAD]"]), float( - cross_entropy_loss(type_predictions, b_symbols_tokenized)) - - def eval_epoch(self, dataloader, cross_entropy_loss): - r"""Average the evaluation of all the batch. - - Args: - dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols - """ - s_total, s_correct, w_total, w_correct = (0.1,) * 4 - - for step, batch in enumerate(dataloader): - batch = batch - batch_output, loss = self.eval_batch(batch, cross_entropy_loss) - ((bs_correct, bs_total), (bw_correct, bw_total)) = batch_output - s_total += bs_total - s_correct += bs_correct - w_total += bw_total - w_correct += bw_correct - - return s_correct / s_total, w_correct / w_total, loss diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py index 6a39ae7658faf456301399cb811d634c8d9f0d38..1f439200a3ea7b520d6b4c26de098772fdaf6bdc 100644 --- a/SuperTagger/Linker/Linker.py +++ b/SuperTagger/Linker/Linker.py @@ -44,7 +44,7 @@ class Linker(Module): LayerNorm(self.dim_polarity_transfo, eps=1e-12) ) - def make_decoder_mask(self, atoms_batch) : + def make_decoder_mask(self, atoms_batch): decoder_attn_mask = torch.ones_like(atoms_batch, dtype=torch.float64) decoder_attn_mask[atoms_batch.eq(self.padding_id)] = 0.0 return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_batch.shape[1], 1) @@ -83,10 +83,12 @@ class Linker(Module): # to do select with list of list pos_encoding = pad_sequence( [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence)) - for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0) + for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, + padding_value=0) neg_encoding = pad_sequence( [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence)) - for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0) + for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, + padding_value=0) # pos_encoding = self.pos_transformation(pos_encoding) # neg_encoding = self.neg_transformation(neg_encoding) @@ -95,3 +97,12 @@ class Linker(Module): link_weights.append(sinkhorn(weights, iters=3)) return link_weights + + def predict_axiom_links(self): + return None + + def eval_batch(self): + return None + + def eval_epoch(self): + return None diff --git a/SuperTagger/__pycache__/EncoderDecoder.cpython-38.pyc b/SuperTagger/__pycache__/EncoderDecoder.cpython-38.pyc deleted file mode 100644 index 
Binary files a/SuperTagger/__pycache__/EncoderDecoder.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py
index 372e68cab3a4cf4af8da655a7073f372de39cb5e..7a14ac5e20c2e15d723c978c02382dc7ee5ad72c 100644
--- a/SuperTagger/eval.py
+++ b/SuperTagger/eval.py
@@ -3,45 +3,6 @@ from torch import Tensor
 from torch.nn import Module
 from torch.nn.functional import nll_loss, cross_entropy

-# Another from Kokos function to calculate the accuracy of our predictions vs labels
-def measure_supertagging_accuracy(pred, truth, ignore_idx=0):
-    r"""Evaluation of the decoder's output and the true symbols without considering the padding in the true targets.
-
-    Args:
-        pred: [batch_size, max_symbols_in_sentence] prediction of symbols for each symbols.
-        truth: [batch_size, max_symbols_in_sentence] true values of symbols
-    """
-    correct_symbols = torch.ones(pred.size())
-    correct_symbols[pred != truth] = 0
-    correct_symbols[truth == ignore_idx] = 1
-    num_correct_symbols = correct_symbols.sum().item()
-    num_masked_symbols = len(truth[truth == ignore_idx])
-
-    correct_sentences = correct_symbols.prod(dim=1)
-    num_correct_sentences = correct_sentences.sum().item()
-
-    return (num_correct_sentences, pred.shape[0]), \
-           (num_correct_symbols - num_masked_symbols, pred.shape[0] * pred.shape[1] - num_masked_symbols)
-
-
-def count_sep(xs, sep_id, dim=-1):
-    return xs.eq(sep_id).sum(dim=dim)
-
-
-class NormCrossEntropy(Module):
-    r"""Loss based on the cross entropy, it considers the number of words and ignore the padding.
-    """
-
-    def __init__(self, ignore_index, sep_id, weights=None):
-        super(NormCrossEntropy, self).__init__()
-        self.ignore_index = ignore_index
-        self.sep_id = sep_id
-        self.weights = weights
-
-    def forward(self, predictions, truths):
-        return cross_entropy(predictions.flatten(0, -2), truths.flatten(), weight=self.weights,
-                             reduction='sum', ignore_index=self.ignore_index) / count_sep(truths.flatten(), self.sep_id)
-

 class SinkhornLoss(Module):
     def __init__(self):
@@ -49,4 +10,4 @@ class SinkhornLoss(Module):

     def forward(self, predictions, truths):
         return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
-                   for link, perm in zip(predictions, truths))
\ No newline at end of file
+                   for link, perm in zip(predictions, truths))
diff --git a/SuperTagger/utils.py b/SuperTagger/utils.py
index cfacf2503ba924d1fd3ba07b340e27a2b13a2002..fc1511ee9cf9df9ec70f908f2b5542e61aea4971 100644
--- a/SuperTagger/utils.py
+++ b/SuperTagger/utils.py
@@ -55,32 +55,4 @@ def format_time(elapsed):
     elapsed_rounded = int(round(elapsed))

     # Format as hh:mm:ss
-    return str(datetime.timedelta(seconds=elapsed_rounded))
-
-
-def checkpoint_save(model, opt, epoch, dir, loss):
-    torch.save({
-        'epoch': epoch,
-        'model_state_dict': model.state_dict(),
-        'optimizer_state_dict': opt.state_dict(),
-        'loss': loss,
-    }, dir + '/model_check.pt')
-
-
-def checkpoint_load(model, opt, path):
-    epoch = 0
-    loss = 0
-    print("#" * 15)
-
-    try:
-        checkpoint = torch.load(path)
-        model.load_state_dict(checkpoint['model_state_dict'])
-        opt.load_state_dict(checkpoint['optimizer_state_dict'])
-        epoch = checkpoint['epoch']
-        loss = checkpoint['loss']
-        print("\n The loading checkpoint was successful ! \n")
\n") - print("#" * 10) - except Exception as e: - print("\nCan't load checkpoint model because : " + str(e) + "\n\nUse default model \n") - print("#" * 15) - return model, opt, epoch, loss + return str(datetime.timedelta(seconds=elapsed_rounded)) \ No newline at end of file diff --git a/train.py b/train.py index 25154db1b58abb38c2e7a6ffb94f89fcfb4d4b0d..9287436a5c86f2cd4c2c1fc3548b4a0c46d304b3 100644 --- a/train.py +++ b/train.py @@ -1,7 +1,5 @@ -import os import pickle import time -from datetime import datetime import numpy as np import torch @@ -13,8 +11,8 @@ from transformers import get_cosine_schedule_with_warmup from Configuration import Configuration from SuperTagger.Linker.Linker import Linker from SuperTagger.Linker.atom_map import atom_map -from SuperTagger.eval import NormCrossEntropy, SinkhornLoss -from SuperTagger.utils import format_time, read_csv_pgbar, checkpoint_save, checkpoint_load +from SuperTagger.eval import SinkhornLoss +from SuperTagger.utils import format_time, read_csv_pgbar from torch.utils.tensorboard import SummaryWriter