Commit 8c037d89 authored by Caroline DE POURTALES

correction

parent d1c8b813
Related merge requests: !6 Linker with transformer, !5 Linker with transformer, !3 Working on padding
@@ -4,7 +4,7 @@ transformers = 4.16.2
 [DATASET_PARAMS]
 symbols_vocab_size=26
 atom_vocab_size=18
-max_len_sentence=83
+max_len_sentence=290
 max_atoms_in_sentence=875
 max_atoms_in_one_type=324
@@ -12,7 +12,7 @@ max_atoms_in_one_type=324
 dim_encoder = 768
 [MODEL_LINKER]
-nhead=16
+nhead=8
 dim_emb_atom = 256
 dim_feedforward_transformer = 768
 num_layers=3
@@ -25,6 +25,6 @@ sinkhorn_iters=5
 [MODEL_TRAINING]
 batch_size=32
-epoch=25
+epoch=30
 seed_val=42
 learning_rate=2e-3
\ No newline at end of file
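For context, a minimal sketch of how these INI hyperparameters could be loaded with Python's standard configparser. The file name config.ini is an assumption; the section and key names come from the diff above.

# Sketch only (not from this commit): reading the tweaked hyperparameters.
from configparser import ConfigParser

config = ConfigParser()
config.read("config.ini")  # hypothetical path

max_len_sentence = config.getint("DATASET_PARAMS", "max_len_sentence")  # 290 after this commit
nhead = config.getint("MODEL_LINKER", "nhead")                          # 8 after this commit
epochs = config.getint("MODEL_TRAINING", "epoch")                       # 30 after this commit
learning_rate = config.getfloat("MODEL_TRAINING", "learning_rate")      # 2e-3

Note that in PyTorch multi-head attention the number of heads must divide the embedding dimension, presumably dim_emb_atom = 256 here; both the old value 16 and the new value 8 satisfy this, so the change adjusts per-head width rather than fixing a shape error.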
import os
from datetime import datetime

from torch.nn import DataParallel, Module

from Linker import *


class DataParallelModel(Module):
    """Wraps a Linker in torch.nn.DataParallel so batches are split across GPUs."""

    def __init__(self):
        super().__init__()
        self.linker = DataParallel(Linker("models/flaubert_super_98_V2_50e.pt"))

    def forward(self, x):
        x = self.linker(x)
        return x

    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
                     batch_size=32, checkpoint=True, tensorboard=False):
        r"""
        Args:
            df_axiom_links : pandas DataFrame containing the atoms annotated with _i
            validation_rate : float
            epochs : int
            batch_size : int
            checkpoint : boolean
            tensorboard : boolean
        Returns:
            Final accuracy and final loss
        """
        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                            validation_rate)
        if checkpoint or tensorboard:
            checkpoint_dir, writer = output_create_dir()

        for epoch_i in range(epochs):
            print("")
            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
            print('Training...')
            avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)

            print("")
            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')

            if validation_rate > 0.0:
                loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')

            if checkpoint:
                # timestamped checkpoint, e.g. Output/linker23-05_14-30.pt
                self.__checkpoint_save(
                    path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))

            if tensorboard:
                writer.add_scalars('Accuracy', {'Train': avg_accuracy_train}, epoch_i)
                writer.add_scalars('Loss', {'Train': avg_train_loss}, epoch_i)
                if validation_rate > 0.0:
                    writer.add_scalars('Accuracy', {'Validation': accuracy_test}, epoch_i)
                    writer.add_scalars('Loss', {'Validation': loss_test}, epoch_i)

            print('\n')
\ No newline at end of file
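A hypothetical usage sketch, not part of this commit: instantiate the wrapper and launch training on a DataFrame of axiom links. The CSV path is an assumption, and the DataFrame is expected to expose the columns (such as "Y" and "Z") consumed by the preprocessing calls above.

# Hypothetical driver code for DataParallelModel.
import pandas as pd

df_axiom_links = pd.read_csv("Datasets/axiom_links.csv")  # assumed path

model = DataParallelModel()
model.train_linker(df_axiom_links, validation_rate=0.1, epochs=30,
                   batch_size=32, checkpoint=True, tensorboard=True)

DataParallel replicates the wrapped Linker on every visible GPU and splits each batch along its first dimension; torch.nn.parallel.DistributedDataParallel is the usual alternative when scaling beyond a single machine.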
@@ -22,7 +22,7 @@ from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
 from Linker.AtomTokenizer import AtomTokenizer
 from Linker.atom_map import atom_map, atom_map_redux
 from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx
 from Supertagger import SuperTagger
 from utils import pad_sequence
@@ -149,7 +149,7 @@ class Linker(Module):
         num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
         pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
-        neg_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
+        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
                                             df_axiom_links["Y"])
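The second hunk fixes a copy-paste bug: neg_idx was previously computed with get_pos_idx, so positive and negative atom occurrences received identical index tensors. A toy sketch of the intended split by polarity, with deliberately simplified signatures (the real helpers in Linker.utils_linker also group by atom type and pad to max_atoms_in_one_type):

# Toy illustration, not the repo's implementation: each polarity must get
# its own index list, since axiom links pair positive with negative atoms.
def get_idx_by_polarity(polarity_batch, want_positive):
    return [[i for i, pol in enumerate(sentence) if pol == want_positive]
            for sentence in polarity_batch]

polarities = [[True, False, True]]  # True = positive occurrence
pos_idx = get_idx_by_polarity(polarities, want_positive=True)   # [[0, 2]]
neg_idx = get_idx_by_polarity(polarities, want_positive=False)  # [[1]]
assert pos_idx != neg_idx  # before this fix, both came from get_pos_idx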