Commit 2889028a authored by Caroline DE POURTALES

change embedding

parent 66d4667d
3 merge requests: !6 Linker with transformer, !5 Linker with transformer, !2 Change preprocess
@@ -12,7 +12,7 @@ max_atoms_in_one_type=510
 dim_encoder = 768
 [MODEL_DECODER]
-nhead=8
+nhead=4
 num_layers=1
 dropout=0.1
 dim_feedforward=512
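nhead drops from 8 to 4 for the decoder. One constraint to keep in mind with this parameter (a general PyTorch note, not a rationale stated in the commit): torch.nn.MultiheadAttention requires the embedding dimension to be divisible by the number of heads, each head then attending over embed_dim // num_heads channels. A sketch with a hypothetical decoder dimension (only nhead=4 and dropout=0.1 come from the section above):

import torch
from torch.nn import MultiheadAttention

dim_decoder = 256   # hypothetical value, not shown in the excerpt above
nhead = 4
dropout = 0.1

# Construction fails unless embed_dim is divisible by num_heads.
assert dim_decoder % nhead == 0
mha = MultiheadAttention(embed_dim=dim_decoder, num_heads=nhead, dropout=dropout)

x = torch.rand(12, 2, dim_decoder)   # (seq_len, batch, embed_dim) with the default batch_first=False
out, attn_weights = mha(x, x, x)
print(out.shape)            # torch.Size([12, 2, 256])
print(attn_weights.shape)   # torch.Size([2, 12, 12]) -- weights averaged over heads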
......
-import torch
-from torch.nn import Module, Embedding
-
-
-class AtomEmbedding(Module):
-    def __init__(self, dim_linker, atom_vocab_size, padding_idx=None):
-        super(AtomEmbedding, self).__init__()
-        self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_linker, padding_idx=padding_idx,
-                             scale_grad_by_freq=True)
-
-    def forward(self, x):
-        return self.emb(x)
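AtomEmbedding is only a thin wrapper around torch.nn.Embedding, which is what lets the Linker below instantiate Embedding directly with the same arguments. A minimal sketch of that equivalence (vocabulary size, dimension and the example indices are illustrative, not values from the config):

import torch
from torch.nn import Embedding

atom_vocab_size, dim_embedding_atoms, padding_id = 20, 8, 0   # illustrative values

# Same arguments the wrapper forwarded to nn.Embedding.
atoms_embedding = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_embedding_atoms,
                            padding_idx=padding_id, scale_grad_by_freq=True)

atom_ids = torch.tensor([[1, 5, 3, 0, 0]])   # a padded sequence of atom indices
vectors = atoms_embedding(atom_ids)          # (batch, seq_len, dim_embedding_atoms)
print(vectors.shape)                         # torch.Size([1, 5, 8])
print(vectors[0, 3].abs().sum().item())      # 0.0 -- the padding row is initialised to zeros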
@@ -6,14 +6,13 @@ import datetime
 import time

 import torch.nn.functional as F
-from torch.nn import Sequential, LayerNorm, Dropout
+from torch.nn import Sequential, LayerNorm, Dropout, Embedding
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm

 from Configuration import Configuration
-from Linker.AtomEmbedding import AtomEmbedding
 from Linker.AtomTokenizer import AtomTokenizer
 from Linker.MHA import AttentionDecoderLayer
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
@@ -21,7 +20,6 @@ from Linker.atom_map import atom_map, atom_map_redux
 from Linker.eval import mesure_accuracy, SinkhornLoss
 from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx
 from Supertagger import *
-from utils import pad_sequence


 def format_time(elapsed):
@@ -62,7 +60,7 @@ class Linker(Module):
         atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
         self.dropout = Dropout(0.1)
-        self.device = "cpu"
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

         supertagger = SuperTagger()
         supertagger.load_weights(supertagger_path_model)
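Resolving the device once up front is the usual pattern: the same self.device can then move the module and every batch tensor before a forward pass. A minimal sketch of the pattern (the Linear layer and the batch are illustrative, not repository code):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Module parameters and input tensors must live on the same device.
layer = torch.nn.Linear(4, 2).to(device)
batch = torch.rand(3, 4).to(device)
print(layer(batch).device)   # cuda:0 when a GPU is visible, otherwise cpu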
@@ -73,7 +71,9 @@
         self.padding_id = self.atom_map['[PAD]']
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         self.inverse_map = self.atoms_tokenizer.inverse_atom_map
-        self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, atom_vocab_size, self.padding_id)
+        self.atoms_embedding = Embedding(num_embeddings=atom_vocab_size, embedding_dim=self.dim_embedding_atoms,
+                                         padding_idx=self.padding_id,
+                                         scale_grad_by_freq=True)

         self.linker_encoder = AttentionDecoderLayer()
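Two of the arguments carried over into the direct Embedding call are easy to gloss over: padding_idx pins the [PAD] row to zeros and keeps it out of gradient updates, and scale_grad_by_freq divides each row's gradient by how often that index occurs in the mini-batch, so frequent atom types are not over-updated relative to rare ones. A small sketch of both effects (sizes and indices are illustrative):

import torch
from torch.nn import Embedding

emb = Embedding(num_embeddings=6, embedding_dim=4, padding_idx=0, scale_grad_by_freq=True)

ids = torch.tensor([[2, 2, 2, 4, 0]])   # index 2 appears three times, index 4 once
emb(ids).sum().backward()

print(emb.weight.grad[2])   # ones: the raw gradient of 3 is divided by the frequency 3
print(emb.weight.grad[4])   # ones: frequency 1, so unscaled
print(emb.weight.grad[0])   # zeros: the padding row receives no gradient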
@@ -90,8 +90,6 @@
         self.optimizer = AdamW(self.parameters(),
                                lr=learning_rate)

-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
         self.to(self.device)

     def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
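__preprocess_data itself is not shown in this diff, but the TensorDataset and random_split imports above suggest the usual split-by-rate pattern for validation_rate. A sketch of that pattern under those assumptions (the tensors and loader wiring are illustrative, not the repository's code):

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Illustrative tensors standing in for tokenised atom sequences and link targets.
sentences = torch.randint(0, 20, (100, 12))
links = torch.randint(0, 12, (100, 5))

dataset = TensorDataset(sentences, links)
validation_rate = 0.1
nb_val = int(len(dataset) * validation_rate)
train_set, val_set = random_split(dataset, [len(dataset) - nb_val, nb_val])

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)
print(len(train_set), len(val_set))   # 90 10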
@@ -171,9 +169,8 @@ class Linker(Module):
             link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))

         total_link_weights = torch.stack(link_weights)
-        link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)
-        return F.log_softmax(link_weights_per_batch, dim=3)
+        return F.log_softmax(total_link_weights, dim=3)

     def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
                      batch_size=32, checkpoint=True, tensorboard=False):
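Dropping the permute means the stacked Sinkhorn outputs keep their natural layout: one slice per atom type first, then the batch, which is the (atom_vocab_size, batch_size, ...) convention the updated eval.py documents. A shape sketch under that reading (all sizes are illustrative):

import torch
import torch.nn.functional as F

nb_atom_types, batch_size, max_atoms = 3, 2, 4   # illustrative sizes

# One (batch, max_atoms, max_atoms) score matrix per atom type, as produced inside the loop.
link_weights = [torch.rand(batch_size, max_atoms, max_atoms) for _ in range(nb_atom_types)]

total_link_weights = torch.stack(link_weights)
print(total_link_weights.shape)   # torch.Size([3, 2, 4, 4]) -- atom type first, no permute

# dim=3 normalises each source atom's scores over its candidate target atoms.
log_probs = F.log_softmax(total_link_weights, dim=3)
print(log_probs.exp().sum(dim=3)[0, 0])   # tensor([1., 1., 1., 1.])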
......
 from .Linker import Linker
 from .atom_map import atom_map
-from .AtomEmbedding import AtomEmbedding
 from .AtomTokenizer import AtomTokenizer
\ No newline at end of file
@@ -9,14 +9,15 @@ class SinkhornLoss(Module):
     def forward(self, predictions, truths):
         return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
-                   for link, perm in zip(predictions, truths))
+                   for link, perm in zip(predictions, truths.permute(1, 0, 2)))


 def mesure_accuracy(batch_true_links, axiom_links_pred):
     r"""
-    batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
-    axiom_links_pred : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
+    batch_true_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
+    axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
+    batch_true_links = batch_true_links.permute(1, 0, 2)
     correct_links = torch.ones(axiom_links_pred.size())
     correct_links[axiom_links_pred != batch_true_links] = 0
     correct_links[batch_true_links == -1] = 1
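Both changes in eval.py follow from that layout: predictions now arrive atom-type-first, so the loss permutes the batch-first truths inside the zip, and mesure_accuracy permutes them once up front before the element-wise comparison, with -1 marking padded slots that always count as correct. A small sketch of the masking logic (shapes and values are illustrative):

import torch

# Illustrative shapes: 1 atom type, 2 sentences, 3 atom slots per type.
axiom_links_pred = torch.tensor([[[2, 1, 0],
                                  [1, 0, 2]]])    # (atom_vocab_size, batch_size, max_atoms)
batch_true_links = torch.tensor([[[2, 0, -1]],
                                 [[1, -1, -1]]])  # (batch_size, atom_vocab_size, max_atoms)

batch_true_links = batch_true_links.permute(1, 0, 2)   # -> (atom_vocab_size, batch_size, max_atoms)

correct_links = torch.ones(axiom_links_pred.size())
correct_links[axiom_links_pred != batch_true_links] = 0   # mismatched links
correct_links[batch_true_links == -1] = 1                 # padding always counts as correct
print(correct_links)
# tensor([[[1., 0., 1.],
#          [1., 1., 1.]]])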
......