From 3296f9db4acc2bc584bfff8d75b0eeb4896844f9 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Wed, 1 Jun 2022 09:42:02 +0200
Subject: [PATCH] some corrections

---
 Configuration/config.ini | 6 +++---
 Linker/Linker.py         | 5 ++++-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/Configuration/config.ini b/Configuration/config.ini
index 249f474..7bdebd8 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -11,14 +11,14 @@ max_atoms_in_one_type=510
 dim_encoder = 768
 
 [MODEL_LINKER]
-dim_cat_out=256
-dim_intermediate_FFN=128
+dim_cat_out=512
+dim_intermediate_FFN=256
 dim_pre_sinkhorn_transfo=32
 dropout=0.1
 sinkhorn_iters=3
 
 [MODEL_TRAINING]
 batch_size=32
-epoch=30
+epoch=25
 seed_val=42
 learning_rate=2e-4
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 4b298f8..b1f7dbf 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -7,7 +7,7 @@ import time
 
 import torch
 import torch.nn.functional as F
-from torch.nn import Sequential, LayerNorm, Module, Linear
+from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
@@ -61,6 +61,7 @@ class Linker(Module):
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
         self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
+        dropout = float(Configuration.modelTrainingConfig['dropout'])
 
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         supertagger = SuperTagger()
@@ -78,6 +79,7 @@ class Linker(Module):
 
         dim_cat = dim_encoder * 2
         self.linker_encoder = Linear(dim_cat, self.dim_cat_out, bias=False)
+        self.dropout = Dropout(dropout)
 
         self.pos_transformation = Sequential(
             FFN(self.dim_cat_out, dim_intermediate_FFN, 0.1, d_out=dim_pre_sinkhorn_transfo),
@@ -161,6 +163,7 @@ class Linker(Module):
         # cat
         atoms_sentences_encoding = torch.cat([sents_embedding_repeat, position_encoding], dim=2)
         atoms_encoding = self.linker_encoder(atoms_sentences_encoding)
+        atoms_encoding = self.dropout(atoms_encoding)
 
         # linking per atom type
         link_weights = []
-- 
GitLab
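
For context, a minimal standalone sketch (not part of the patch) of what the Linker.py change wires up: a bias-free `Linear` encoder over the concatenated sentence and positional encodings, followed by the newly added `Dropout` applied to its output before per-atom-type linking. Dimensions here are illustrative, taken from the config hunk above; the atom-count dimension is a made-up placeholder. Note also that the patch reads the `dropout` key from `modelTrainingConfig`, while the visible config.ini hunk defines `dropout=0.1` under `[MODEL_LINKER]`; unless `[MODEL_TRAINING]` also defines `dropout` outside the shown diff context, that lookup would raise a `KeyError`.

```python
# Sketch of the patched encoder path, assuming values from the config hunk above.
import torch
from torch.nn import Linear, Dropout

dim_encoder = 768       # [DATASET] dim_encoder
dim_cat_out = 512       # [MODEL_LINKER] dim_cat_out (raised from 256 by this patch)
dropout_rate = 0.1      # [MODEL_LINKER] dropout, read via Configuration in the patch

linker_encoder = Linear(dim_encoder * 2, dim_cat_out, bias=False)
dropout = Dropout(dropout_rate)

# (batch, atoms, 2 * dim_encoder): sentence embedding concatenated with
# the positional encoding; batch_size=32 from the config, atom count is arbitrary.
atoms_sentences_encoding = torch.randn(32, 100, dim_encoder * 2)
atoms_encoding = linker_encoder(atoms_sentences_encoding)
atoms_encoding = dropout(atoms_encoding)  # new in this patch: regularize before linking
```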