Commit 8dc363bd authored by Caroline DE POURTALES

starting train

parent 85e492e2
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
Showing changes with 18 additions and 180 deletions
File deleted
File deleted
import torch
from torch.nn import Dropout, Module

from Configuration import Configuration
from SuperTagger.Decoder.RNNDecoderLayer import RNNDecoderLayer
from SuperTagger.Encoder.EncoderLayer import EncoderLayer
from SuperTagger.eval import measure_supertagging_accuracy


class EncoderDecoder(Module):
    """
    A standard encoder-decoder architecture: a pretrained BERT-style encoder
    (EncoderLayer) followed by an RNN decoder (RNNDecoderLayer) that predicts
    the symbol sequence.
    """

    def __init__(self, BASE_TOKENIZER, BASE_MODEL, symbols_map):
        super(EncoderDecoder, self).__init__()
        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
        self.max_symbols_in_sentence = int(Configuration.datasetConfig['max_symbols_in_sentence'])
        self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])

        self.symbols_map = symbols_map
        self.sents_padding_id = BASE_TOKENIZER.pad_token_id
        self.sents_space_id = BASE_TOKENIZER.bos_token_id

        self.encoder = EncoderLayer(BASE_MODEL)
        self.decoder = RNNDecoderLayer(self.symbols_map)
        self.dropout = Dropout(0.1)

    def forward(self, sents_tokenized_batch, sents_mask_batch, symbols_tokenized_batch):
        r"""Trains the translation from a sentence to its symbols.

        Args:
            sents_tokenized_batch: [batch_size, max_len_sentence] tokenized sentences
            sents_mask_batch: attention mask produced by the encoder tokenizer
            symbols_tokenized_batch: [batch_size, max_symbols_in_sentence] true symbols for each sentence
        """
        last_hidden_state, pooler_output = self.encoder([sents_tokenized_batch, sents_mask_batch])
        last_hidden_state = self.dropout(last_hidden_state)
        return self.decoder(symbols_tokenized_batch, last_hidden_state, pooler_output)

    def decode_greedy_rnn(self, sents_tokenized_batch, sents_mask_batch):
        r"""Predicts the symbols for each sentence in sents_tokenized_batch.

        Args:
            sents_tokenized_batch: [batch_size, max_len_sentence] tokenized sentences
            sents_mask_batch: attention mask produced by the encoder tokenizer
        """
        last_hidden_state, pooler_output = self.encoder([sents_tokenized_batch, sents_mask_batch])
        last_hidden_state = self.dropout(last_hidden_state)
        predictions = self.decoder.predict_rnn(last_hidden_state, pooler_output)
        return predictions

    def eval_batch(self, batch, cross_entropy_loss):
        r"""Predicts the symbols for the sentences in batch and evaluates them against the true symbols.

        Args:
            batch: tokenized sentences, their masks and the true symbols
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        b_sents_tokenized = batch[0].to(device)
        b_sents_mask = batch[1].to(device)
        b_symbols_tokenized = batch[2].to(device)

        type_predictions = self.decode_greedy_rnn(b_sents_tokenized, b_sents_mask)
        pred = torch.argmax(type_predictions, dim=2)

        inverse_map = {v: k for k, v in self.symbols_map.items()}
        predict_trad = [inverse_map[int(i)] for i in pred[0]]
        true_trad = [inverse_map[int(i)] for i in b_symbols_tokenized[0]]

        num_true = len([i for i in true_trad if i != '[PAD]'])
        print("\nsub true (", num_true, ") : ",
              [token for token in true_trad if token != '[PAD]'])
        print("\nsub predict (", len([i for i in predict_trad if i != '[PAD]']), ") : ",
              [token for token in predict_trad if token != '[PAD]'])

        accuracy = measure_supertagging_accuracy(pred, b_symbols_tokenized,
                                                 ignore_idx=self.symbols_map["[PAD]"])
        loss = float(cross_entropy_loss(type_predictions, b_symbols_tokenized))
        return accuracy, loss

    def eval_epoch(self, dataloader, cross_entropy_loss):
        r"""Averages sentence and word accuracy over all batches of the dataloader;
        the returned loss is the one of the last batch.

        Args:
            dataloader: yields batches of tokenized sentences, their masks and the true symbols
        """
        s_total, s_correct, w_total, w_correct = (0.1,) * 4

        for step, batch in enumerate(dataloader):
            batch_output, loss = self.eval_batch(batch, cross_entropy_loss)
            ((bs_correct, bs_total), (bw_correct, bw_total)) = batch_output
            s_total += bs_total
            s_correct += bs_correct
            w_total += bw_total
            w_correct += bw_correct

        return s_correct / s_total, w_correct / w_total, loss
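A minimal, hypothetical sketch of how this (now removed) class could be driven for prediction, assuming the project's Configuration and SuperTagger modules are importable and configured; the camembert-base checkpoint and the toy symbols_map are illustrative placeholders, not values taken from this repository:

import torch
from transformers import AutoModel, AutoTokenizer

# Assumption: a French BERT-style checkpoint; the project may use a different one.
BASE_TOKENIZER = AutoTokenizer.from_pretrained("camembert-base")
BASE_MODEL = AutoModel.from_pretrained("camembert-base")
symbols_map = {"[PAD]": 0, "[SOS]": 1, "np": 2, "s": 3}  # toy symbol vocabulary

model = EncoderDecoder(BASE_TOKENIZER, BASE_MODEL, symbols_map)
model.eval()

# Toy batch of two sentences padded to 16 sub-word tokens.
sents = torch.randint(5, 100, (2, 16))
masks = torch.ones_like(sents)

with torch.no_grad():
    logits = model.decode_greedy_rnn(sents, masks)   # [batch, max_symbols_in_sentence, n_symbols]
    predicted_symbols = logits.argmax(dim=2)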
@@ -44,7 +44,7 @@ class Linker(Module):
            LayerNorm(self.dim_polarity_transfo, eps=1e-12)
        )

-    def make_decoder_mask(self, atoms_batch) :
+    def make_decoder_mask(self, atoms_batch):
        decoder_attn_mask = torch.ones_like(atoms_batch, dtype=torch.float64)
        decoder_attn_mask[atoms_batch.eq(self.padding_id)] = 0.0
        return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_batch.shape[1], 1)
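The mask construction in this hunk can be checked in isolation. The standalone sketch below, which assumes a padding id of 0, shows that padded positions are zeroed and the [batch, seq] mask is expanded to the [batch, seq, seq] shape expected by attention layers:

import torch

padding_id = 0                                   # assumption: 0 is the padding index
atoms_batch = torch.tensor([[5, 3, 0, 0],
                            [7, 2, 4, 0]])
mask = torch.ones_like(atoms_batch, dtype=torch.float64)
mask[atoms_batch.eq(padding_id)] = 0.0
mask = mask.unsqueeze(1).repeat(1, atoms_batch.shape[1], 1)
print(mask.shape)                                # torch.Size([2, 4, 4])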
@@ -83,10 +83,12 @@ class Linker(Module):
        # to do select with list of list
        pos_encoding = pad_sequence(
            [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
-             for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0)
+             for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_sentence,
+            padding_value=0)
        neg_encoding = pad_sequence(
            [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
-             for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0)
+             for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_sentence,
+            padding_value=0)

        # pos_encoding = self.pos_transformation(pos_encoding)
        # neg_encoding = self.neg_transformation(neg_encoding)
@@ -95,3 +97,12 @@ class Linker(Module):
            link_weights.append(sinkhorn(weights, iters=3))

        return link_weights
+
+    def predict_axiom_links(self):
+        return None
+
+    def eval_batch(self):
+        return None
+
+    def eval_epoch(self):
+        return None
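The sinkhorn function called above is not part of this diff. A common way such a normalisation is implemented is a few log-domain Sinkhorn iterations that push the raw link weights towards a doubly stochastic matrix; the sketch below illustrates that idea and is not necessarily the project's own implementation:

import torch

def sinkhorn_log(weights: torch.Tensor, iters: int = 3) -> torch.Tensor:
    """weights: [batch, n, n] raw scores; returns the log of an approximately doubly stochastic matrix."""
    log_alpha = weights
    for _ in range(iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)  # normalise rows
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)  # normalise columns
    return log_alpha

scores = torch.randn(1, 4, 4)
log_perm = sinkhorn_log(scores, iters=3)
print(log_perm.exp().sum(dim=2))   # rows sum to approximately 1 after a few iterations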
File deleted
@@ -3,45 +3,6 @@ from torch import Tensor
from torch.nn import Module
from torch.nn.functional import nll_loss, cross_entropy

-
-# Another function from Kokos to calculate the accuracy of our predictions vs labels
-def measure_supertagging_accuracy(pred, truth, ignore_idx=0):
-    r"""Evaluates the decoder's output against the true symbols, ignoring the padding in the true targets.
-
-    Args:
-        pred: [batch_size, max_symbols_in_sentence] predicted symbols for each sentence
-        truth: [batch_size, max_symbols_in_sentence] true symbols for each sentence
-    """
-    correct_symbols = torch.ones(pred.size())
-    correct_symbols[pred != truth] = 0
-    correct_symbols[truth == ignore_idx] = 1
-    num_correct_symbols = correct_symbols.sum().item()
-    num_masked_symbols = len(truth[truth == ignore_idx])
-
-    correct_sentences = correct_symbols.prod(dim=1)
-    num_correct_sentences = correct_sentences.sum().item()
-
-    return (num_correct_sentences, pred.shape[0]), \
-           (num_correct_symbols - num_masked_symbols, pred.shape[0] * pred.shape[1] - num_masked_symbols)
-
-
-def count_sep(xs, sep_id, dim=-1):
-    return xs.eq(sep_id).sum(dim=dim)
-
-
-class NormCrossEntropy(Module):
-    r"""Cross-entropy loss that ignores the padding and is normalised by the number of words."""
-
-    def __init__(self, ignore_index, sep_id, weights=None):
-        super(NormCrossEntropy, self).__init__()
-        self.ignore_index = ignore_index
-        self.sep_id = sep_id
-        self.weights = weights
-
-    def forward(self, predictions, truths):
-        return cross_entropy(predictions.flatten(0, -2), truths.flatten(), weight=self.weights,
-                             reduction='sum', ignore_index=self.ignore_index) / count_sep(truths.flatten(), self.sep_id)
-
-
class SinkhornLoss(Module):
    def __init__(self):
@@ -49,4 +10,4 @@ class SinkhornLoss(Module):
    def forward(self, predictions, truths):
        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
-                  for link, perm in zip(predictions, truths))
\ No newline at end of file
+                  for link, perm in zip(predictions, truths))
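A toy usage sketch of the retained SinkhornLoss, assuming each prediction is a [batch, n, n] matrix of log-probabilities over permutations and each truth a [batch, n] tensor of gold indices; the batch size and number of atoms are placeholders:

import torch
from SuperTagger.eval import SinkhornLoss

criterion = SinkhornLoss()
# One atom type, batch of two sentences, three atoms of that type per sentence.
predictions = [torch.log_softmax(torch.randn(2, 3, 3), dim=2)]
truths = [torch.tensor([[0, 2, 1], [1, 0, 2]])]
loss = criterion(predictions, truths)
print(float(loss))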
@@ -55,32 +55,4 @@ def format_time(elapsed):
    elapsed_rounded = int(round(elapsed))
    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))
-
-
-def checkpoint_save(model, opt, epoch, dir, loss):
-    torch.save({
-        'epoch': epoch,
-        'model_state_dict': model.state_dict(),
-        'optimizer_state_dict': opt.state_dict(),
-        'loss': loss,
-    }, dir + '/model_check.pt')
-
-
-def checkpoint_load(model, opt, path):
-    epoch = 0
-    loss = 0
-    print("#" * 15)
-
-    try:
-        checkpoint = torch.load(path)
-        model.load_state_dict(checkpoint['model_state_dict'])
-        opt.load_state_dict(checkpoint['optimizer_state_dict'])
-        epoch = checkpoint['epoch']
-        loss = checkpoint['loss']
-        print("\n The loading checkpoint was successful ! \n")
-        print("#" * 10)
-    except Exception as e:
-        print("\nCan't load checkpoint model because : " + str(e) + "\n\nUse default model \n")
-    print("#" * 15)
-    return model, opt, epoch, loss
-    return str(datetime.timedelta(seconds=elapsed_rounded))
\ No newline at end of file
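The removed helpers wrap a plain torch.save / torch.load round trip. The self-contained sketch below reproduces that round trip with a placeholder model, optimizer and output directory:

import os
import torch

model = torch.nn.Linear(4, 2)                              # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

os.makedirs("Output", exist_ok=True)
torch.save({'epoch': 3,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'loss': 0.42},
           "Output/model_check.pt")

checkpoint = torch.load("Output/model_check.pt")
model.load_state_dict(checkpoint['model_state_dict'])
opt.load_state_dict(checkpoint['optimizer_state_dict'])
print(checkpoint['epoch'], checkpoint['loss'])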
import os
import pickle
import time
from datetime import datetime
import numpy as np
import torch
@@ -13,8 +11,8 @@ from transformers import get_cosine_schedule_with_warmup
from Configuration import Configuration
from SuperTagger.Linker.Linker import Linker
from SuperTagger.Linker.atom_map import atom_map
-from SuperTagger.eval import NormCrossEntropy, SinkhornLoss
-from SuperTagger.utils import format_time, read_csv_pgbar, checkpoint_save, checkpoint_load
+from SuperTagger.eval import SinkhornLoss
+from SuperTagger.utils import format_time, read_csv_pgbar

from torch.utils.tensorboard import SummaryWriter
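These imports outline the training setup the commit message points to ("starting train"): an optimiser with a cosine warmup schedule and TensorBoard logging. The skeleton below is a hypothetical sketch of that loop; the Linker constructor arguments, the dataset and the SinkhornLoss inputs are project-specific, so a placeholder module, random tensors and an MSE criterion stand in for them:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from transformers import get_cosine_schedule_with_warmup

model = torch.nn.Linear(8, 8)                        # placeholder for Linker(...)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
loader = DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 8)), batch_size=8)
epochs = 3
scheduler = get_cosine_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=epochs * len(loader))
writer = SummaryWriter(log_dir="logs")

step = 0
for epoch in range(epochs):
    for x, y in loader:
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)   # stands in for SinkhornLoss on link weights
        loss.backward()
        optimizer.step()
        scheduler.step()
        writer.add_scalar("train/loss", loss.item(), step)
        step += 1
writer.close()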