Skip to content
Snippets Groups Projects
Commit d1c8b813 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

change padding handling

parent a539008d
No related branches found
No related tags found
3 merge requests!6Linker with transformer,!5Linker with transformer,!3Working on padding
......@@ -4,7 +4,7 @@ transformers = 4.16.2
[DATASET_PARAMS]
symbols_vocab_size=26
atom_vocab_size=18
max_len_sentence=290
max_len_sentence=83
max_atoms_in_sentence=875
max_atoms_in_one_type=324
......@@ -12,10 +12,10 @@ max_atoms_in_one_type=324
dim_encoder = 768
[MODEL_LINKER]
nhead=8
nhead=16
dim_emb_atom = 256
dim_feedforward_transformer = 768
num_layers=2
num_layers=3
dim_cat_inter=512
dim_cat_out=256
dim_intermediate_FFN=128
......
import datetime
from torch.nn import DataParallel, Module
from Linker import *
class DataParallelModel(Module):
    """Wrap a ``Linker`` in ``torch.nn.DataParallel`` and drive its training loop."""

    def __init__(self, model_path="models/flaubert_super_98_V2_50e.pt"):
        """
        Args:
            model_path : str, passed as the only argument to ``Linker``
                (defaults to the path that used to be hard-coded here)
        """
        super().__init__()
        self.linker = DataParallel(Linker(model_path))

    def forward(self, x):
        """Forward ``x`` through the data-parallel linker and return its output."""
        return self.linker(x)

    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
                     batch_size=32, checkpoint=True, tensorboard=False):
        r"""
        Run ``epochs`` training epochs, optionally evaluating, checkpointing and
        logging to tensorboard after each one.

        Args:
            df_axiom_links : pandas dataFrame containing the atoms annotated with _i
            validation_rate : float, evaluation is skipped when it is not > 0.0
            epochs : int
            batch_size : int
            checkpoint : boolean, save a checkpoint under "Output" after each epoch
            tensorboard : boolean, write train/validation scalars after each epoch
        Returns:
            None. Losses and accuracies are printed and, when ``tensorboard`` is
            set, written through the writer returned by ``output_create_dir``.
        """
        # Local imports: `import datetime` at file level binds the *module*, so
        # `datetime.today()` only works if a star import happens to rebind the
        # name to the class; `os` likewise was only reachable via a star import.
        import os
        from datetime import datetime as _datetime

        # NOTE(review): inside this class `self.__preprocess_data` and
        # `self.__checkpoint_save` are name-mangled to
        # `_DataParallelModel__preprocess_data` / `_DataParallelModel__checkpoint_save`.
        # No such methods (nor `train_epoch` / `eval_epoch`) are visible in this
        # class -- presumably they live on `Linker`; confirm and delegate explicitly.
        training_dataloader, validation_dataloader = self.__preprocess_data(
            batch_size, df_axiom_links, validation_rate)

        if checkpoint or tensorboard:
            checkpoint_dir, writer = output_create_dir()

        run_validation = validation_rate > 0.0

        for epoch_i in range(epochs):
            print("")
            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
            print('Training...')

            avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)

            print("")
            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')

            if run_validation:
                loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')

            if checkpoint:
                # The file name carries day-month_hour-minute, so epochs that
                # finish within the same minute overwrite the same file.
                timestamp = _datetime.today().strftime('%d-%m_%H-%M')
                self.__checkpoint_save(
                    path=os.path.join("Output", 'linker' + timestamp + '.pt'))

            if tensorboard:
                writer.add_scalars('Accuracy', {'Train': avg_accuracy_train}, epoch_i)
                writer.add_scalars('Loss', {'Train': avg_train_loss}, epoch_i)
                if run_validation:
                    writer.add_scalars('Accuracy', {'Validation': accuracy_test}, epoch_i)
                    writer.add_scalars('Loss', {'Validation': loss_test}, epoch_i)

            # NOTE(review): placed at epoch-loop level; the original indentation
            # was lost in extraction -- confirm.
            print('\n')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment