Skip to content
Snippets Groups Projects
Commit d1c8b813 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

change padding handling

parent a539008d
Branches
No related tags found
3 merge requests!6Linker with transformer,!5Linker with transformer,!3Working on padding
......@@ -4,7 +4,7 @@ transformers = 4.16.2
[DATASET_PARAMS]
symbols_vocab_size=26
atom_vocab_size=18
max_len_sentence=290
max_len_sentence=83
max_atoms_in_sentence=875
max_atoms_in_one_type=324
......@@ -12,10 +12,10 @@ max_atoms_in_one_type=324
dim_encoder = 768
[MODEL_LINKER]
nhead=8
nhead=16
dim_emb_atom = 256
dim_feedforward_transformer = 768
num_layers=2
num_layers=3
dim_cat_inter=512
dim_cat_out=256
dim_intermediate_FFN=128
......
import datetime
from torch.nn import DataParallel, Module
from Linker import *
class DataParallelModel(Module):
def __init__(self):
super().__init__()
self.linker = DataParallel(Linker("models/flaubert_super_98_V2_50e.pt"))
def forward(self, x):
x = self.linker(x)
return x
def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
batch_size=32, checkpoint=True, tensorboard=False):
r"""
Args:
df_axiom_links : pandas dataFrame containing the atoms anoted with _i
validation_rate : float
epochs : int
batch_size : int
checkpoint : boolean
tensorboard : boolean
Returns:
Final accuracy and final loss
"""
training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
validation_rate)
if checkpoint or tensorboard:
checkpoint_dir, writer = output_create_dir()
for epoch_i in range(epochs):
print("")
print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
print('Training...')
avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)
print("")
print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
if validation_rate > 0.0:
loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
if checkpoint:
self.__checkpoint_save(
path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
if tensorboard:
writer.add_scalars(f'Accuracy', {
'Train': avg_accuracy_train}, epoch_i)
writer.add_scalars(f'Loss', {
'Train': avg_train_loss}, epoch_i)
if validation_rate > 0.0:
writer.add_scalars(f'Accuracy', {
'Validation': accuracy_test}, epoch_i)
writer.add_scalars(f'Loss', {
'Validation': loss_test}, epoch_i)
print('\n')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment