Commit 3296f9db authored by Caroline DE POURTALES

some corrections

parent 62c4b4ee
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
@@ -11,14 +11,14 @@ max_atoms_in_one_type=510
 dim_encoder = 768
 
 [MODEL_LINKER]
-dim_cat_out=256
-dim_intermediate_FFN=128
+dim_cat_out=512
+dim_intermediate_FFN=256
 dim_pre_sinkhorn_transfo=32
 dropout=0.1
 sinkhorn_iters=3
 
 [MODEL_TRAINING]
 batch_size=32
-epoch=30
+epoch=25
 seed_val=42
 learning_rate=2e-4
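For context, a minimal sketch of how the updated settings could be read, assuming the repository's Configuration module wraps Python's configparser and the file is named config.ini (both names are assumptions, not confirmed by this diff):

from configparser import ConfigParser

# Hypothetical illustration: read the linker and training settings changed above.
parser = ConfigParser()
parser.read("config.ini")

dim_cat_out = int(parser["MODEL_LINKER"]["dim_cat_out"])                      # 512 after this commit
dim_intermediate_FFN = int(parser["MODEL_LINKER"]["dim_intermediate_FFN"])    # 256 after this commit
dropout = float(parser["MODEL_LINKER"]["dropout"])                            # 0.1
epochs = int(parser["MODEL_TRAINING"]["epoch"])                               # 25 after this commit
learning_rate = float(parser["MODEL_TRAINING"]["learning_rate"])              # 2e-4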
@@ -7,7 +7,7 @@ import time
 import torch
 import torch.nn.functional as F
-from torch.nn import Sequential, LayerNorm, Module, Linear
+from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
@@ -61,6 +61,7 @@ class Linker(Module):
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
         self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
+        dropout = float(Configuration.modelTrainingConfig['dropout'])
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         supertagger = SuperTagger()
@@ -78,6 +79,7 @@ class Linker(Module):
         dim_cat = dim_encoder * 2
         self.linker_encoder = Linear(dim_cat, self.dim_cat_out, bias=False)
+        self.dropout = Dropout(dropout)
 
         self.pos_transformation = Sequential(
             FFN(self.dim_cat_out, dim_intermediate_FFN, 0.1, d_out=dim_pre_sinkhorn_transfo),
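A standalone illustration of the Dropout module introduced above (not the Linker's own code; the rate 0.1 is an assumption matching dropout=0.1 in the config). Dropout only zeroes activations while the module is in training mode; in eval mode it is the identity:

import torch
from torch.nn import Dropout

drop = Dropout(p=0.1)
x = torch.ones(2, 4)

drop.train()
print(drop(x))   # some entries zeroed, survivors scaled by 1 / (1 - 0.1)

drop.eval()
print(drop(x))   # unchanged: dropout is disabled at evaluation time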
@@ -161,6 +163,7 @@ class Linker(Module):
         # cat
         atoms_sentences_encoding = torch.cat([sents_embedding_repeat, position_encoding], dim=2)
         atoms_encoding = self.linker_encoder(atoms_sentences_encoding)
+        atoms_encoding = self.dropout(atoms_encoding)
 
         # linking per atom type
         link_weights = []
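For reference, a minimal sketch of the encoding path this commit touches: concatenate the sentence and positional embeddings, project them with the linker encoder, then apply the new dropout. The shapes (batch=2, atoms=5) and the random input tensors are assumptions chosen only to make the snippet runnable; dim_encoder=768 and dim_cat_out=512 follow the config above.

import torch
from torch.nn import Linear, Dropout

dim_encoder, dim_cat_out = 768, 512
sents_embedding_repeat = torch.randn(2, 5, dim_encoder)   # stand-in for the real sentence embeddings
position_encoding = torch.randn(2, 5, dim_encoder)        # stand-in for the atom position encodings

linker_encoder = Linear(dim_encoder * 2, dim_cat_out, bias=False)
dropout = Dropout(0.1)

atoms_sentences_encoding = torch.cat([sents_embedding_repeat, position_encoding], dim=2)
atoms_encoding = linker_encoder(atoms_sentences_encoding)
atoms_encoding = dropout(atoms_encoding)   # regularisation step added in this commit
print(atoms_encoding.shape)                # torch.Size([2, 5, 512])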