import os
import re
import sys
import datetime
import time
import torch
import torch.nn.functional as F
from torch.nn import Sequential, LayerNorm, Module, Linear
from torch.optim import AdamW
from torch.utils.data import TensorDataset, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from Configuration import Configuration
from Linker.AtomTokenizer import AtomTokenizer
from Linker.PositionEncoding import PositionalEncoding
from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from Linker.atom_map import atom_map, atom_map_redux
from Linker.eval import mesure_accuracy, SinkhornLoss
from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx, get_num_atoms_batch
from Supertagger import SuperTagger
from utils import pad_sequence
def format_time(elapsed):
'''
Takes a time in seconds and returns a string hh:mm:ss
'''
# Round to the nearest second.
elapsed_rounded = int(round(elapsed))
# Format as hh:mm:ss
return str(datetime.timedelta(seconds=elapsed_rounded))
def output_create_dir():
"""
    Create the output dir for TensorBoard logs and checkpoints
    @return: output dir, tensorboard writer
"""
from datetime import datetime
    output_path = 'TensorBoard'
    training_dir = os.path.join(output_path, 'Training_' + datetime.today().strftime('%d-%m_%H-%M'))
logs_dir = os.path.join(training_dir, 'logs')
writer = SummaryWriter(log_dir=logs_dir)
return training_dir, writer
class Linker(Module):
def __init__(self, supertagger_path_model):
super(Linker, self).__init__()
dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
self.dim_cat_out = int(Configuration.modelLinkerConfig['dim_cat_out'])
dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo'])
dim_intermediate_FFN = int(Configuration.modelLinkerConfig['dim_intermediate_FFN'])
self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
supertagger = SuperTagger()
supertagger.load_weights(supertagger_path_model)
self.Supertagger = supertagger
self.atom_map = atom_map
self.sub_atoms_type_list = list(atom_map_redux.keys())
self.padding_id = self.atom_map['[PAD]']
self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
self.inverse_map = self.atoms_tokenizer.inverse_atom_map
self.position_encoding = PositionalEncoding(dim_encoder, max_len=self.max_atoms_in_sentence)
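        # The word-context embedding and the positionally-encoded category embedding (dim_encoder each)
        # are concatenated, then projected down to dim_cat_out by the linker encoder below.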
dim_cat = dim_encoder * 2
self.linker_encoder = Linear(dim_cat, self.dim_cat_out, bias=False)
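        # Two separate FFN heads map the shared atom encodings to the space in which positive and
        # negative occurrences of an atom type are matched by the Sinkhorn iterations.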
self.pos_transformation = Sequential(
FFN(self.dim_cat_out, dim_intermediate_FFN, 0.1, d_out=dim_pre_sinkhorn_transfo),
LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-12)
)
self.neg_transformation = Sequential(
FFN(self.dim_cat_out, dim_intermediate_FFN, 0.1, d_out=dim_pre_sinkhorn_transfo),
LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-12)
)
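        # Loss over the predicted link log-probabilities against the gold axiom links (see Linker.eval.SinkhornLoss).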
self.cross_entropy_loss = SinkhornLoss()
self.optimizer = AdamW(self.parameters(),
lr=learning_rate)
self.to(self.device)
def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
r"""
Args:
batch_size : int
df_axiom_links pandas DataFrame
validation_rate
Returns:
the training dataloader and the validation dataloader. They contains the list of atoms, their polarities, the axiom links, the sentences tokenized, sentence mask
"""
print("Start preprocess Data")
sentences_batch = df_axiom_links["X"].tolist()
sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
pos_idx = get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
        # get_neg_idx is assumed to exist in Linker.utils_linker as the negative-polarity counterpart of get_pos_idx
        neg_idx = get_neg_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, self.sub_atoms_type_list, atoms_polarity_batch,
df_axiom_links["Y"])
truth_links_batch = truth_links_batch.permute(1, 0, 2)
# Construction tensor dataset
dataset = TensorDataset(num_atoms_per_word, pos_idx, neg_idx, truth_links_batch, sentences_tokens,
sentences_mask)
if validation_rate > 0.0:
            train_size = int((1 - validation_rate) * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
else:
validation_dataloader = None
train_dataset = dataset
training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
print("End preprocess Data")
return training_dataloader, validation_dataloader
def forward(self, batch_num_atoms_per_word, batch_pos_idx, batch_neg_idx, sents_embedding, cat_embedding):
r"""
Args:
batch_num_atoms_per_word : (batch_size, len_sentence) flattened categories
batch_pos_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
batch_neg_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
cat_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for cat embedding
Returns:
link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat) log probabilities
"""
# repeat embedding word for each atom in word
sents_embedding_repeat = pad_sequence(
[torch.repeat_interleave(input=sents_embedding[i], repeats=batch_num_atoms_per_word[i], dim=0)
for i in range(len(sents_embedding))], max_len=self.max_atoms_in_sentence, padding_value=0)
cat_embedding_repeat = pad_sequence(
[torch.repeat_interleave(input=cat_embedding[i], repeats=batch_num_atoms_per_word[i], dim=0)
for i in range(len(cat_embedding))], max_len=self.max_atoms_in_sentence, padding_value=0)
# positional encoding of atoms and cat embedding to form the atom embedding
position_encoding = self.position_encoding(cat_embedding_repeat)
# cat
atoms_sentences_encoding = torch.cat([sents_embedding_repeat, position_encoding], dim=2)
atoms_encoding = self.linker_encoder(atoms_sentences_encoding)
# linking per atom type
link_weights = []
for atom_type in self.sub_atoms_type_list:
pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type)
neg_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_neg_idx, atom_type)
pos_encoding = self.pos_transformation(pos_encoding)
neg_encoding = self.neg_transformation(neg_encoding)
weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
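            # Sinkhorn normalisation turns the raw similarity scores into (log-space) soft permutation
            # matrices, one per atom type, matching positive atoms to negative atoms.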
link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
total_link_weights = torch.stack(link_weights)
return F.log_softmax(total_link_weights, dim=3)
def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
batch_size=32, checkpoint=True, tensorboard=False):
r"""
Args:
df_axiom_links : pandas dataFrame containing the atoms anoted with _i
validation_rate : float
epochs : int
batch_size : int
checkpoint : boolean
tensorboard : boolean
Returns:
Final accuracy and final loss
"""
training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
validation_rate)
if checkpoint or tensorboard:
checkpoint_dir, writer = output_create_dir()
for epoch_i in range(epochs):
print("")
print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
print('Training...')
avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)
print("")
print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
if validation_rate > 0.0:
loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
if checkpoint:
self.__checkpoint_save(
path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
if tensorboard:
                writer.add_scalars('Accuracy', {'Train': avg_accuracy_train}, epoch_i)
                writer.add_scalars('Loss', {'Train': avg_train_loss}, epoch_i)
                if validation_rate > 0.0:
                    writer.add_scalars('Accuracy', {'Validation': accuracy_test}, epoch_i)
                    writer.add_scalars('Loss', {'Validation': loss_test}, epoch_i)
print('\n')
def train_epoch(self, training_dataloader):
r""" Train epoch
Args:
training_dataloader : DataLoader from torch , contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
Returns:
accuracy on validation set
loss on train set
"""
self.train()
# Reset the total loss for this epoch.
epoch_loss = 0
accuracy_train = 0
t0 = time.time()
# For each batch of training data...
with tqdm(training_dataloader, unit="batch") as tepoch:
for batch in tepoch:
# Unpack this training batch from our dataloader
batch_num_atoms = batch[0].to(self.device)
batch_pos_idx = batch[1].to(self.device)
batch_neg_idx = batch[2].to(self.device)
batch_true_links = batch[3].to(self.device)
batch_sentences_tokens = batch[4].to(self.device)
batch_sentences_mask = batch[5].to(self.device)
self.optimizer.zero_grad()
# get sentence embedding from BERT which is already trained
output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
                # Run the linker on the category predictions
logits_predictions = self(batch_num_atoms, batch_pos_idx, batch_neg_idx, output['word_embeding'],
output['last_hidden_state'])
linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
                epoch_loss += float(linker_loss)
                # Perform a backward pass to calculate the gradients.
                linker_loss.backward()
# This is to help prevent the "exploding gradients" problem.
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)
# Update parameters and take a step using the computed gradient.
self.optimizer.step()
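                # The predicted link of each positive atom is the negative atom with the highest log-probability.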
pred_axiom_links = torch.argmax(logits_predictions, dim=3)
accuracy_train += mesure_accuracy(batch_true_links, pred_axiom_links)
# Measure how long this epoch took.
training_time = format_time(time.time() - t0)
avg_train_loss = epoch_loss / len(training_dataloader)
avg_accuracy_train = accuracy_train / len(training_dataloader)
return avg_train_loss, avg_accuracy_train, training_time
def eval_batch(self, batch):
batch_num_atoms = batch[0].to(self.device)
batch_pos_idx = batch[1].to(self.device)
batch_neg_idx = batch[2].to(self.device)
batch_true_links = batch[3].to(self.device)
batch_sentences_tokens = batch[4].to(self.device)
batch_sentences_mask = batch[5].to(self.device)
output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
logits_predictions = self(batch_num_atoms, batch_pos_idx, batch_neg_idx, output['word_embeding'],
output['last_hidden_state'])
axiom_links_pred = torch.argmax(logits_predictions, dim=3)
        print('\n')
        print("Sentence tokens: ", batch_sentences_tokens[1])
        print("Positive atom positions in the sentence: ", batch_pos_idx[1][:50])
        print("Negative atom positions in the sentence: ", batch_neg_idx[1][:50])
        print("True links for atom type n: ", batch_true_links[1][2][:100])
        print("Predictions: ", axiom_links_pred[1][2][:100])
        print('\n')
accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
return loss, accuracy
def eval_epoch(self, dataloader):
r"""Average the evaluation of all the batch.
Args:
dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols
"""
self.eval()
accuracy_average = 0
loss_average = 0
with torch.no_grad():
for step, batch in enumerate(dataloader):
loss, accuracy = self.eval_batch(batch)
accuracy_average += accuracy
loss_average += float(loss)
return loss_average / len(dataloader), accuracy_average / len(dataloader)
def load_weights(self, model_file):
print("#" * 15)
try:
params = torch.load(model_file, map_location=self.device)
args = params['args']
self.atom_map = args['atom_map']
self.max_atoms_in_sentence = args['max_atoms_in_sentence']
self.atoms_tokenizer = AtomTokenizer(self.atom_map, self.max_atoms_in_sentence)
self.linker_encoder.load_state_dict(params['linker_encoder'])
self.pos_transformation.load_state_dict(params['pos_transformation'])
self.neg_transformation.load_state_dict(params['neg_transformation'])
self.optimizer.load_state_dict(params['optimizer'])
print("\n The loading checkpoint was successful ! \n")
except Exception as e:
print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr)
raise e
print("#" * 15)
def __checkpoint_save(self, path='/linker.pt'):
"""
@param path:
"""
self.cpu()
torch.save({
'args': dict(atom_map=self.atom_map, max_atoms_in_sentence=self.max_atoms_in_sentence),
'linker_encoder': self.linker_encoder.state_dict(),
'pos_transformation': self.pos_transformation.state_dict(),
'neg_transformation': self.neg_transformation.state_dict(),
            'optimizer': self.optimizer.state_dict(),
}, path)
self.to(self.device)
def make_sinkhorn_inputs(self, bsd_tensor, positional_ids, atom_type):
"""
:param bsd_tensor:
Tensor of shape batch size \times sequence length \times feature dimensionality.
:param positional_ids:
A List of batch_size elements, each being a List of num_atoms LongTensors.
Each LongTensor in positional_ids[b][a] indexes the location of atoms of type a in sentence b.
:param atom_type:
:return:
"""
return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0, index=int(atom)).to(self.device)
if atom != -1 else torch.zeros(self.dim_cat_out, device=self.device)
for atom in sentence])
for i, sentence in enumerate(positional_ids[:, atom_map_redux[atom_type], :])])
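

# Minimal usage sketch (illustrative, not part of the original module). It assumes a trained
# supertagger checkpoint and a DataFrame with the columns used above ("X": sentences,
# "Y": axiom links, "Z": categories); the paths below are placeholders.
#
#   linker = Linker("models/supertagger.pt")
#   linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=20,
#                       batch_size=32, checkpoint=True, tensorboard=True)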