Commit a76e9380 authored by Julien Rabault

Merge branch 'version-linker' of https://gitlab.irit.fr/pnria/global-helper/deepgrail-linker into version-linker

# Conflicts:
#	Configuration/config.ini
#	train.py
parents 9c45838b 1b94de2d
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
@@ -12,15 +12,14 @@ max_atoms_in_one_type=250
 dim_encoder = 768

 [MODEL_DECODER]
-dim_decoder = 16
 num_rnn_layers=1
+dim_decoder = 32
 dropout=0.1
 teacher_forcing=0.05

 [MODEL_LINKER]
 nhead=4
 dim_feedforward=246
-dim_embedding_atoms=16
+dim_embedding_atoms=32
 dim_polarity_transfo=128
 layer_norm_eps=1e-5
 dropout=0.1
@@ -31,4 +30,4 @@ device=cpu
 batch_size=16
 epoch=20
 seed_val=42
-learning_rate=2e-5
\ No newline at end of file
+learning_rate=2e-5
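
Note on how these values are consumed: train.py below reads them through the project's Configuration wrapper (e.g. Configuration.modelTrainingConfig['batch_size']). A minimal sketch with the standard-library configparser, assuming the training section is named MODEL_TRAINING (the section name is not visible in this diff):

from configparser import ConfigParser

config = ConfigParser()
config.read('Configuration/config.ini')

# Values touched by this commit: decoder and atom-embedding widths go 16 -> 32.
dim_decoder = int(config['MODEL_DECODER']['dim_decoder'])                 # 32
dim_embedding_atoms = int(config['MODEL_LINKER']['dim_embedding_atoms'])  # 32
learning_rate = float(config['MODEL_TRAINING']['learning_rate'])          # 2e-5; section name assumed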
@@ -9,6 +9,7 @@ import sys
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
+from transformers import get_cosine_schedule_with_warmup

 from Configuration import Configuration
@@ -23,6 +24,19 @@ from Linker.eval import mesure_accuracy, SinkhornLoss
 from utils import pad_sequence

+
+def output_create_dir():
+    """
+    Create the output dir for tensorboard and checkpoint
+    @return: output dir, tensorboard writer
+    """
+    from datetime import datetime
+    outpout_path = 'TensorBoard'
+    training_dir = os.path.join(outpout_path, 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M'))
+    logs_dir = os.path.join(training_dir, 'logs')
+    writer = SummaryWriter(log_dir=logs_dir)
+    return training_dir, writer
+
 class Linker(Module):
     def __init__(self, supertagger):
         super(Linker, self).__init__()
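
A hedged usage sketch of the new output_create_dir helper: it returns a timestamped run directory under TensorBoard/ together with a SummaryWriter whose event files land in its logs/ subfolder (the scalar written here is purely illustrative):

training_dir, writer = output_create_dir()   # e.g. TensorBoard/Tranning_07-05_14-32
writer.add_scalar('Loss/Train', 0.42, 0)     # event file written under training_dir/logs
writer.flush()
writer.close()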
@@ -67,20 +81,20 @@ class Linker(Module):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    def __preprocess_data(self, batch_size, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.0):
+    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.0):
         r"""
         Args:
             batch_size : int
             df_axiom_links : pandas DataFrame
-            sentences_tokens
-            sentences_mask
             validation_rate
         Returns:
             the training dataloader and the validation dataloader. They contain the lists of atoms, their polarities, the axiom links, the tokenized sentences and the sentence masks
         """
         sentences_batch = df_axiom_links["Sentences"].tolist()
+        sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)

         atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
-        atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
-        atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
+        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)

         atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["sub_tree"])
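
The rest of __preprocess_data is outside this hunk; given the TensorDataset and random_split imports at the top of the file, the train/validation split plausibly looks like this sketch (the exact tensors bundled into the dataset, and the DataLoader wrapping, are assumptions):

from torch.utils.data import DataLoader, TensorDataset, random_split

dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch,
                        truth_links_batch,            # hypothetical name for the axiom-link targets
                        sentences_tokens, sentences_mask)
if validation_rate > 0.0:
    # Hold out a fraction of the data for the eval_epoch pass below.
    train_size = int((1 - validation_rate) * len(dataset))
    train_set, val_set = random_split(dataset, [train_size, len(dataset) - train_size])
    validation_dataloader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
else:
    train_set, validation_dataloader = dataset, None
training_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)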
@@ -153,29 +167,56 @@ class Linker(Module):
         return F.log_softmax(link_weights_per_batch, dim=3)

-    def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20,
-                     batch_size=32, checkpoint=True, validate=True):
+    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
+                     batch_size=32, checkpoint=True, tensorboard=False):
         r"""
         Args:
             df_axiom_links : pandas DataFrame containing the atoms annotated with _i
-            sentences_tokens : sentences tokenized by BERT
-            sentences_mask : mask of tokens
             validation_rate : float
             epochs : int
             batch_size : int
             checkpoint : boolean
-            validate : boolean
+            tensorboard : boolean
         Returns:
             Final accuracy and final loss
         """
-        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
-                                                                            sentences_tokens, sentences_mask,
-                                                                            validation_rate)
+        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
+                                                                            validation_rate)
         self.to(self.device)
-        for epoch_i in range(0, epochs):
-            epoch_acc, epoch_loss = self.train_epoch(training_dataloader, validation_dataloader, checkpoint, validate)

-    def train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True):
+        if checkpoint or tensorboard:
+            checkpoint_dir, writer = output_create_dir()
+
+        for epoch_i in range(epochs):
+            avg_train_loss, avg_accuracy_train = self.train_epoch(training_dataloader)
+            print("Average Loss on train dataset : ", avg_train_loss)
+            print("Average Accuracy on train dataset : ", avg_accuracy_train)
+
+            if checkpoint:
+                self.__checkpoint_save(
+                    path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
+
+            if validation_rate > 0.0:
+                with torch.no_grad():
+                    loss_test, accuracy_test = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
+                    print("Average Loss on test dataset : ", loss_test)
+                    print("Average Accuracy on test dataset : ", accuracy_test)
+
+            if tensorboard:
+                writer.add_scalars('Accuracy', {'Train': avg_accuracy_train}, epoch_i)
+                writer.add_scalars('Loss', {'Train': avg_train_loss}, epoch_i)
+                if validation_rate > 0.0:
+                    writer.add_scalars('Accuracy', {'Validation': accuracy_test}, epoch_i)
+                    writer.add_scalars('Loss', {'Validation': loss_test}, epoch_i)
+
+            print('\n')
+
+    def train_epoch(self, training_dataloader):
         r""" Train epoch
         Args:
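
For context on the new get_cosine_schedule_with_warmup import: self.optimizer and self.scheduler are created outside this diff, but the standard wiring of AdamW to a cosine schedule looks like this sketch (warmup length and learning rate are assumptions; linker stands in for the Linker instance):

from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup

optimizer = AdamW(linker.parameters(), lr=2e-5)
total_steps = len(training_dataloader) * epochs   # one scheduler step per batch
scheduler = get_cosine_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,            # assumed
                                            num_training_steps=total_steps)
# Each batch then calls optimizer.step() followed by scheduler.step(),
# as the train_epoch hunk below shows.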
@@ -188,6 +229,7 @@ class Linker(Module):
         # Reset the total loss for this epoch.
         epoch_loss = 0
+        accuracy_train = 0

         self.train()
@@ -220,22 +262,13 @@ class Linker(Module):
             self.optimizer.step()
             self.scheduler.step()

-            avg_train_loss = epoch_loss / len(training_dataloader)
-            print("Average Loss on train dataset : ", avg_train_loss)
+            pred_axiom_links = torch.argmax(logits_predictions, dim=3)
+            accuracy_train += mesure_accuracy(batch_true_links, pred_axiom_links)

-        if checkpoint:
-            self.__checkpoint_save(
-                path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
-
-        if validate:
-            with torch.no_grad():
-                accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
-                print("Average Loss on test dataset : ", average_test_loss)
-                print("Average Accuracy on test dataset : ", accuracy)
-
-        print('\n')
-
-        return accuracy, avg_train_loss
+        avg_train_loss = epoch_loss / len(training_dataloader)
+        avg_accuracy_train = accuracy_train / len(training_dataloader)
+
+        return avg_train_loss, avg_accuracy_train
def predict(self, categories, sents_embedding, sents_mask=None):
r"""Prediction from categories output by BERT and hidden_state from BERT
@@ -302,7 +335,7 @@ class Linker(Module):
         accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
         loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)

-        return accuracy, loss
+        return loss, accuracy
def eval_epoch(self, dataloader, cross_entropy_loss):
r"""Average the evaluation of all the batch.
@@ -312,14 +345,12 @@ class Linker(Module):
         """
         accuracy_average = 0
         loss_average = 0
-        compt = 0
         for step, batch in enumerate(dataloader):
-            compt += 1
-            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
+            loss, accuracy = self.eval_batch(batch, cross_entropy_loss)
             accuracy_average += accuracy
-            loss_average += loss
+            loss_average += float(loss)

-        return accuracy_average / compt, loss_average / compt
+        return loss_average / len(dataloader), accuracy_average / len(dataloader)
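
The loss_average += float(loss) change is worth a note: casting each batch loss to a Python float stops the evaluation loop from accumulating tensors (and their device memory) across batches, and dividing by len(dataloader), the number of batches, replaces the hand-maintained compt counter. The pattern in isolation:

loss_average = 0.0
for batch in dataloader:
    loss, accuracy = eval_batch(batch)   # stands in for self.eval_batch(batch, cross_entropy_loss)
    loss_average += float(loss)          # keep the value, drop the tensor
loss_average /= len(dataloader)          # len(dataloader) counts batches, so no separate counter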
def load_weights(self, model_file):
print("#" * 15)
@@ -6,7 +6,7 @@ from utils import read_csv_pgbar
 torch.cuda.empty_cache()

 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 200
+nb_sentences = batch_size * 20
 epochs = int(Configuration.modelTrainingConfig['epoch'])

 file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'