Commit a76e9380 authored by Julien Rabault


Merge branch 'version-linker' of https://gitlab.irit.fr/pnria/global-helper/deepgrail-linker into version-linker

# Conflicts:
#	Configuration/config.ini
#	train.py
parents 9c45838b 1b94de2d
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
Configuration/config.ini
@@ -12,15 +12,14 @@ max_atoms_in_one_type=250
 dim_encoder = 768
 
 [MODEL_DECODER]
-dim_decoder = 16
-num_rnn_layers=1
+dim_decoder = 32
 dropout=0.1
 teacher_forcing=0.05
 
 [MODEL_LINKER]
 nhead=4
 dim_feedforward=246
-dim_embedding_atoms=16
+dim_embedding_atoms=32
 dim_polarity_transfo=128
 layer_norm_eps=1e-5
 dropout=0.1
@@ -31,4 +30,4 @@ device=cpu
 batch_size=16
 epoch=20
 seed_val=42
-learning_rate=2e-5
\ No newline at end of file
+learning_rate=2e-5
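The decoder and atom-embedding widths move from 16 to 32 in lockstep, presumably because both feed the same linker dimensions. For context, a minimal sketch of how such an INI file is read on the Python side, assuming a standard configparser-backed Configuration module (the variable names here are illustrative):

    from configparser import ConfigParser

    config = ConfigParser()
    config.read('Configuration/config.ini')

    # INI values come back as strings, so numeric fields need explicit casts
    dim_decoder = int(config['MODEL_DECODER']['dim_decoder'])                 # 32 after this merge
    dim_embedding_atoms = int(config['MODEL_LINKER']['dim_embedding_atoms'])  # 32, kept equal to dim_decoder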
Linker/Linker.py
@@ -9,6 +9,7 @@ import sys
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
+from torch.utils.tensorboard import SummaryWriter
 from transformers import get_cosine_schedule_with_warmup
 
 from Configuration import Configuration
@@ -23,6 +24,19 @@ from Linker.eval import mesure_accuracy, SinkhornLoss
 from utils import pad_sequence
 
 
+def output_create_dir():
+    """
+    Create the output dir for the TensorBoard logs and checkpoints
+    @return: output dir, tensorboard writer
+    """
+    from datetime import datetime
+    output_path = 'TensorBoard'
+    training_dir = os.path.join(output_path, 'Training_' + datetime.today().strftime('%d-%m_%H-%M'))
+    logs_dir = os.path.join(training_dir, 'logs')
+    writer = SummaryWriter(log_dir=logs_dir)
+    return training_dir, writer
+
+
 class Linker(Module):
     def __init__(self, supertagger):
         super(Linker, self).__init__()
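The helper above gives every run its own timestamped directory under TensorBoard/, so successive trainings do not overwrite each other's logs. Once a run has written events, they can be browsed with the standard TensorBoard CLI:

    tensorboard --logdir TensorBoard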
@@ -67,20 +81,20 @@ class Linker(Module):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    def __preprocess_data(self, batch_size, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.0):
+    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.0):
         r"""
         Args:
             batch_size : int
             df_axiom_links : pandas DataFrame
-            sentences_tokens
-            sentences_mask
             validation_rate
 
         Returns:
             the training dataloader and the validation dataloader; they contain the list of atoms, their polarities, the axiom links, the tokenized sentences and the sentence masks
         """
+        sentences_batch = df_axiom_links["Sentences"].tolist()
+        sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
+
         atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
-        atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
-        atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
+        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
         atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["sub_tree"])
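Given the TensorDataset / random_split imports at the top of the file, the remainder of __preprocess_data presumably bundles these tensors and splits them by validation_rate. A rough sketch of that standard pattern, with illustrative tensor names rather than the method's exact ones:

    from torch.utils.data import DataLoader, TensorDataset, random_split

    dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch,
                            truth_links_batch, sentences_tokens, sentences_mask)
    if validation_rate > 0.0:
        val_size = int(validation_rate * len(dataset))
        train_set, val_set = random_split(dataset, [len(dataset) - val_size, val_size])
    else:
        train_set, val_set = dataset, None

    # one loader per split; only the training data is shuffled
    training_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    validation_dataloader = DataLoader(val_set, batch_size=batch_size) if val_set else None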
@@ -153,29 +167,56 @@ class Linker(Module):
         return F.log_softmax(link_weights_per_batch, dim=3)
 
-    def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20,
-                     batch_size=32, checkpoint=True, validate=True):
+    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
+                     batch_size=32, checkpoint=True, tensorboard=False):
         r"""
         Args:
             df_axiom_links : pandas DataFrame containing the atoms annotated with _i
-            sentences_tokens : sentences tokenized by BERT
-            sentences_mask : mask of tokens
             validation_rate : float
             epochs : int
             batch_size : int
             checkpoint : boolean
-            validate : boolean
+            tensorboard : boolean
         Returns:
             Final accuracy and final loss
         """
         training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
-                                                                            sentences_tokens, sentences_mask,
                                                                             validation_rate)
         self.to(self.device)
 
-        for epoch_i in range(0, epochs):
-            epoch_acc, epoch_loss = self.train_epoch(training_dataloader, validation_dataloader, checkpoint, validate)
+        if checkpoint or tensorboard:
+            checkpoint_dir, writer = output_create_dir()
+
+        for epoch_i in range(epochs):
+            avg_train_loss, avg_accuracy_train = self.train_epoch(training_dataloader)
+            print("Average Loss on train dataset : ", avg_train_loss)
+            print("Average Accuracy on train dataset : ", avg_accuracy_train)
+
+            if checkpoint:
+                self.__checkpoint_save(
+                    path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
+
+            if validation_rate > 0.0:
+                with torch.no_grad():
+                    loss_test, accuracy_test = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
+                    print("Average Loss on test dataset : ", loss_test)
+                    print("Average Accuracy on test dataset : ", accuracy_test)
+
+            if tensorboard:
+                writer.add_scalars('Accuracy', {'Train': avg_accuracy_train}, epoch_i)
+                writer.add_scalars('Loss', {'Train': avg_train_loss}, epoch_i)
+                if validation_rate > 0.0:
+                    writer.add_scalars('Accuracy', {'Validation': accuracy_test}, epoch_i)
+                    writer.add_scalars('Loss', {'Validation': loss_test}, epoch_i)
+
+            print('\n')
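With tokenization now internal, the public entry point needs only the DataFrame. A hedged usage sketch (supertagger and df_axiom_links stand for objects built elsewhere in the project):

    linker = Linker(supertagger)
    linker.train_linker(df_axiom_links,
                        validation_rate=0.1,
                        epochs=20,
                        batch_size=32,
                        checkpoint=True,
                        tensorboard=True)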
-    def train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True):
+    def train_epoch(self, training_dataloader):
         r""" Train epoch
         Args:
@@ -188,6 +229,7 @@ class Linker(Module):
         # Reset the total loss for this epoch.
         epoch_loss = 0
+        accuracy_train = 0
 
         self.train()
 
@@ -220,22 +262,13 @@ class Linker(Module):
             self.optimizer.step()
             self.scheduler.step()
 
-        avg_train_loss = epoch_loss / len(training_dataloader)
-        print("Average Loss on train dataset : ", avg_train_loss)
-
-        if checkpoint:
-            self.__checkpoint_save(
-                path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
-
-        if validate:
-            with torch.no_grad():
-                accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
-                print("Average Loss on test dataset : ", average_test_loss)
-                print("Average Accuracy on test dataset : ", accuracy)
-
-        print('\n')
-
-        return accuracy, avg_train_loss
+            pred_axiom_links = torch.argmax(logits_predictions, dim=3)
+            accuracy_train += mesure_accuracy(batch_true_links, pred_axiom_links)
+
+        avg_train_loss = epoch_loss / len(training_dataloader)
+        avg_accuracy_train = accuracy_train / len(training_dataloader)
+
+        return avg_train_loss, avg_accuracy_train
 
     def predict(self, categories, sents_embedding, sents_mask=None):
         r"""Prediction from categories output by BERT and hidden_state from BERT
@@ -302,7 +335,7 @@ class Linker(Module):
         accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
         loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
 
-        return accuracy, loss
+        return loss, accuracy
 
     def eval_epoch(self, dataloader, cross_entropy_loss):
         r"""Average the evaluation over all the batches.
@@ -312,14 +345,12 @@ class Linker(Module):
         """
         accuracy_average = 0
         loss_average = 0
-        compt = 0
         for step, batch in enumerate(dataloader):
-            compt += 1
-            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
+            loss, accuracy = self.eval_batch(batch, cross_entropy_loss)
             accuracy_average += accuracy
-            loss_average += loss
+            loss_average += float(loss)
 
-        return accuracy_average / compt, loss_average / compt
+        return loss_average / len(dataloader), accuracy_average / len(dataloader)
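A small but real fix hides in loss_average += float(loss): unless the caller already wraps eval_epoch in torch.no_grad(), summing the raw loss tensor keeps every batch's autograd graph alive until the end of the epoch. Casting to float (or calling .item()) extracts a plain Python number instead:

    loss_average += float(loss)   # detaches from the graph
    loss_average += loss.item()   # equivalent for a scalar tensor

Dividing by len(dataloader) also removes the hand-rolled compt counter, since a DataLoader already knows its batch count.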
 
     def load_weights(self, model_file):
         print("#" * 15)
train.py
@@ -6,7 +6,7 @@ from utils import read_csv_pgbar
 torch.cuda.empty_cache()
 
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 200
+nb_sentences = batch_size * 20
 epochs = int(Configuration.modelTrainingConfig['epoch'])
 
 file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
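With the default batch_size of 16 from config.ini, this drops the loaded subset from 16 × 200 = 3200 sentences to 16 × 20 = 320, which reads like a deliberately smaller slice of the gold-and-silver links data for quicker iteration.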