From 3296f9db4acc2bc584bfff8d75b0eeb4896844f9 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Wed, 1 Jun 2022 09:42:02 +0200
Subject: [PATCH] Increase linker layer sizes, reduce training epochs, and add dropout to Linker

---
 Configuration/config.ini | 6 +++---
 Linker/Linker.py         | 5 ++++-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/Configuration/config.ini b/Configuration/config.ini
index 249f474..7bdebd8 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -11,14 +11,14 @@ max_atoms_in_one_type=510
 dim_encoder = 768
 
 [MODEL_LINKER]
-dim_cat_out=256
-dim_intermediate_FFN=128
+dim_cat_out=512
+dim_intermediate_FFN=256
 dim_pre_sinkhorn_transfo=32
 dropout=0.1
 sinkhorn_iters=3
 
 [MODEL_TRAINING]
 batch_size=32
-epoch=30
+epoch=25
 seed_val=42
 learning_rate=2e-4
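
Note: these [MODEL_LINKER] and [MODEL_TRAINING] entries are plain INI values,
so they come back as strings and must be cast explicitly on read. A minimal
sketch using Python's standard configparser (the project's Configuration
module is assumed to wrap something equivalent):

    import configparser

    config = configparser.ConfigParser()
    config.read("Configuration/config.ini")

    # configparser returns strings; cast to the numeric types the model expects
    dim_cat_out = int(config["MODEL_LINKER"]["dim_cat_out"])                    # 512
    dim_intermediate_FFN = int(config["MODEL_LINKER"]["dim_intermediate_FFN"])  # 256
    dropout = float(config["MODEL_LINKER"]["dropout"])                          # 0.1
    epochs = int(config["MODEL_TRAINING"]["epoch"])                             # 25
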
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 4b298f8..b1f7dbf 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -7,7 +7,7 @@ import time
 
 import torch
 import torch.nn.functional as F
-from torch.nn import Sequential, LayerNorm, Module, Linear
+from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
@@ -61,6 +61,7 @@ class Linker(Module):
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
         self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
+        dropout = float(Configuration.modelLinkerConfig['dropout'])  # 'dropout' is set under [MODEL_LINKER], not [MODEL_TRAINING]
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         supertagger = SuperTagger()
@@ -78,6 +79,7 @@ class Linker(Module):
 
         dim_cat = dim_encoder * 2
         self.linker_encoder = Linear(dim_cat, self.dim_cat_out, bias=False)
+        self.dropout = Dropout(dropout)
 
         self.pos_transformation = Sequential(
             FFN(self.dim_cat_out, dim_intermediate_FFN, 0.1, d_out=dim_pre_sinkhorn_transfo),
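
Note: the new Dropout module sits between the concatenating linker encoder and
the FFN transformations that follow it. A self-contained sketch of that wiring,
assuming dim_encoder=768, dim_cat_out=512 and dropout=0.1 as configured above
(tensor names and the batch size are illustrative):

    import torch
    from torch.nn import Dropout, Linear

    dim_encoder, dim_cat_out = 768, 512
    linker_encoder = Linear(dim_encoder * 2, dim_cat_out, bias=False)
    dropout = Dropout(0.1)

    # 32 atoms whose sentence and positional embeddings were concatenated
    atoms_cat = torch.randn(32, dim_encoder * 2)
    atoms_encoding = dropout(linker_encoder(atoms_cat))  # shape: (32, 512)
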
@@ -161,6 +163,7 @@ class Linker(Module):
         # cat
         atoms_sentences_encoding = torch.cat([sents_embedding_repeat, position_encoding], dim=2)
         atoms_encoding = self.linker_encoder(atoms_sentences_encoding)
+        atoms_encoding = self.dropout(atoms_encoding)
 
         # linking per atom type
         link_weights = []
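
Note: torch.nn.Dropout is only active in training mode; after model.eval() it
is the identity, so this extra call costs nothing at inference time. A small
sketch of that behaviour (the Demo module is illustrative, not part of the
Linker):

    import torch
    from torch.nn import Dropout, Module

    class Demo(Module):
        def __init__(self):
            super().__init__()
            self.dropout = Dropout(p=0.5)

        def forward(self, x):
            return self.dropout(x)

    demo, x = Demo(), torch.ones(4)
    demo.train()
    print(demo(x))  # survivors rescaled by 1 / (1 - p) = 2.0, others zeroed
    demo.eval()
    print(demo(x))  # identity: tensor([1., 1., 1., 1.])
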
-- 
GitLab