Commit 3296f9db authored by Caroline DE POURTALES

some corrections

parent 62c4b4ee
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
@@ -11,14 +11,14 @@ max_atoms_in_one_type=510
 dim_encoder = 768
 
 [MODEL_LINKER]
-dim_cat_out=256
-dim_intermediate_FFN=128
+dim_cat_out=512
+dim_intermediate_FFN=256
 dim_pre_sinkhorn_transfo=32
 dropout=0.1
 sinkhorn_iters=3
 
 [MODEL_TRAINING]
 batch_size=32
-epoch=30
+epoch=25
 seed_val=42
 learning_rate=2e-4
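
For reference, these values are consumed at model-construction time through the project's Configuration module, which is not shown in this diff. A minimal sketch of reading the same INI sections with Python's standard configparser; the file path and variable names here are illustrative assumptions, not the project's actual ones:

import configparser

# Illustrative stand-in for the project's Configuration module:
# load the INI file and pull out the values touched by this commit.
config = configparser.ConfigParser()
config.read("config.ini")  # path is an assumption

dim_cat_out = int(config["MODEL_LINKER"]["dim_cat_out"])                    # 512 after this commit
dim_intermediate_FFN = int(config["MODEL_LINKER"]["dim_intermediate_FFN"])  # 256 after this commit
dropout = float(config["MODEL_LINKER"]["dropout"])                          # 0.1, unchanged
epochs = int(config["MODEL_TRAINING"]["epoch"])                             # 25 after this commit
learning_rate = float(config["MODEL_TRAINING"]["learning_rate"])            # 2e-4, unchanged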
@@ -7,7 +7,7 @@ import time
 import torch
 import torch.nn.functional as F
-from torch.nn import Sequential, LayerNorm, Module, Linear
+from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
@@ -61,6 +61,7 @@ class Linker(Module):
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
         self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
+        dropout = float(Configuration.modelTrainingConfig['dropout'])
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         supertagger = SuperTagger()
@@ -78,6 +79,7 @@ class Linker(Module):
         dim_cat = dim_encoder * 2
         self.linker_encoder = Linear(dim_cat, self.dim_cat_out, bias=False)
+        self.dropout = Dropout(dropout)
 
         self.pos_transformation = Sequential(
             FFN(self.dim_cat_out, dim_intermediate_FFN, 0.1, d_out=dim_pre_sinkhorn_transfo),
@@ -161,6 +163,7 @@ class Linker(Module):
         # cat
         atoms_sentences_encoding = torch.cat([sents_embedding_repeat, position_encoding], dim=2)
         atoms_encoding = self.linker_encoder(atoms_sentences_encoding)
+        atoms_encoding = self.dropout(atoms_encoding)
 
         # linking per atom type
         link_weights = []
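
Taken together, the Python-side changes add one regularisation step: a Dropout layer is built from the configured rate and applied to the atom encodings right after the linear projection of the concatenated sentence and positional embeddings. A self-contained sketch of that pattern, using illustrative class and tensor names rather than the project's exact ones:

import torch
from torch.nn import Dropout, Linear, Module


class EncoderWithDropout(Module):
    """Illustrative sketch: project a concatenated encoding, then apply dropout."""

    def __init__(self, dim_encoder=768, dim_cat_out=512, dropout=0.1):
        super().__init__()
        dim_cat = dim_encoder * 2                      # sentence part + positional part
        self.linker_encoder = Linear(dim_cat, dim_cat_out, bias=False)
        self.dropout = Dropout(dropout)                # rate taken from the config file

    def forward(self, sents_embedding_repeat, position_encoding):
        # Concatenate along the feature dimension, project, then regularise.
        atoms_sentences_encoding = torch.cat(
            [sents_embedding_repeat, position_encoding], dim=2)
        atoms_encoding = self.linker_encoder(atoms_sentences_encoding)
        return self.dropout(atoms_encoding)


# Usage with dummy tensors (batch=2, atoms=10, dim_encoder=768):
enc = EncoderWithDropout()
out = enc(torch.randn(2, 10, 768), torch.randn(2, 10, 768))
print(out.shape)  # torch.Size([2, 10, 512])

In training mode the dropout zeroes a random 10 % of the projected features; calling .eval() on the module disables it, which is the standard torch.nn.Dropout behaviour at inference time.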