Commit 2889028a authored by Caroline DE POURTALES

change embedding

parent 66d4667d
3 merge requests: !6 Linker with transformer, !5 Linker with transformer, !2 Change preprocess
@@ -12,7 +12,7 @@ max_atoms_in_one_type=510
dim_encoder = 768
[MODEL_DECODER]
-nhead=8
+nhead=4
num_layers=1
dropout=0.1
dim_feedforward=512
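Note: torch.nn.MultiheadAttention requires the model width to be divisible by the number of heads; with the 768-dimensional width shown above, both 8 and 4 qualify, so the change trades several narrow heads for fewer, wider ones (192 dimensions per head instead of 96). The sketch below only illustrates that constraint; d_model = 768 is an assumption borrowed from dim_encoder, not read from the decoder configuration.

import torch.nn as nn

# Sketch only: the divisibility constraint enforced by nn.MultiheadAttention.
# d_model = 768 is assumed from the dim_encoder value above.
d_model, dim_feedforward, dropout = 768, 512, 0.1
for nhead in (8, 4):
    assert d_model % nhead == 0, "d_model must be divisible by nhead"
    layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
    print(f"{nhead} heads -> {d_model // nhead} dimensions per head")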
......
import torch
from torch.nn import Module, Embedding
class AtomEmbedding(Module):
def __init__(self, dim_linker, atom_vocab_size, padding_idx=None):
super(AtomEmbedding, self).__init__()
self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_linker, padding_idx=padding_idx,
scale_grad_by_freq=True)
def forward(self, x):
return self.emb(x)
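Note: AtomEmbedding is a thin wrapper around torch.nn.Embedding, and the rest of this commit drops the wrapper in favour of a direct Embedding instance. A minimal sketch, assuming the class above is in scope and using illustrative sizes rather than the project's configuration, showing that the two paths give identically shaped outputs:

import torch
from torch.nn import Embedding

atom_vocab_size, dim_linker, padding_id = 20, 256, 0   # illustrative values only

wrapper = AtomEmbedding(dim_linker, atom_vocab_size, padding_id)        # old path
direct = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_linker,
                   padding_idx=padding_id, scale_grad_by_freq=True)     # new path

atoms = torch.randint(0, atom_vocab_size, (2, 10))      # (batch, max_atoms)
assert wrapper(atoms).shape == direct(atoms).shape == (2, 10, dim_linker)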
@@ -6,14 +6,13 @@ import datetime
import time
import torch.nn.functional as F
-from torch.nn import Sequential, LayerNorm, Dropout
+from torch.nn import Sequential, LayerNorm, Dropout, Embedding
from torch.optim import AdamW
from torch.utils.data import TensorDataset, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from Configuration import Configuration
-from Linker.AtomEmbedding import AtomEmbedding
from Linker.AtomTokenizer import AtomTokenizer
from Linker.MHA import AttentionDecoderLayer
from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
@@ -21,7 +20,6 @@ from Linker.atom_map import atom_map, atom_map_redux
from Linker.eval import mesure_accuracy, SinkhornLoss
from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx
from Supertagger import *
-from utils import pad_sequence
def format_time(elapsed):
@@ -62,7 +60,7 @@ class Linker(Module):
atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
self.dropout = Dropout(0.1)
-self.device = "cpu"
+self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
supertagger = SuperTagger()
supertagger.load_weights(supertagger_path_model)
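Note: the device is now resolved once, up front, instead of being hard-coded to "cpu" and overwritten later (the duplicate assignment is removed in a hunk below). Tensors the module builds at run time must live on the same device as its parameters. A minimal sketch of the pattern, with a stand-in module rather than the real Linker:

import torch
from torch.nn import Linear

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Linear(4, 2).to(device)            # stand-in for the Linker module
batch = torch.zeros(8, 4, device=device)   # inputs created on the same device
out = model(batch)                         # no cross-device error
print(out.device)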
@@ -73,7 +71,9 @@
self.padding_id = self.atom_map['[PAD]']
self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
self.inverse_map = self.atoms_tokenizer.inverse_atom_map
-self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, atom_vocab_size, self.padding_id)
+self.atoms_embedding = Embedding(num_embeddings=atom_vocab_size, embedding_dim=self.dim_embedding_atoms,
+                                 padding_idx=self.padding_id,
+                                 scale_grad_by_freq=True)
self.linker_encoder = AttentionDecoderLayer()
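Note: this keeps the options the old wrapper forwarded to torch.nn.Embedding: padding_idx pins the '[PAD]' row to zero and blocks its gradient, and scale_grad_by_freq rescales gradients by the inverse in-batch frequency of each atom index. A small hypothetical check of the padding behaviour, with illustrative sizes:

import torch
from torch.nn import Embedding

emb = Embedding(num_embeddings=5, embedding_dim=4, padding_idx=0,
                scale_grad_by_freq=True)
out = emb(torch.tensor([[0, 1, 1, 2]]))     # index 0 plays the role of [PAD]
out.sum().backward()
assert torch.all(emb.weight.grad[0] == 0)   # the padding row gets no gradient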
@@ -90,8 +90,6 @@
self.optimizer = AdamW(self.parameters(),
lr=learning_rate)
-self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-self.to(self.device)
def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
@@ -171,9 +169,8 @@
link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
total_link_weights = torch.stack(link_weights)
-link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)
-return F.log_softmax(link_weights_per_batch, dim=3)
+return F.log_softmax(total_link_weights, dim=3)
def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
batch_size=32, checkpoint=True, tensorboard=False):
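Note: total_link_weights stacks one Sinkhorn matrix per atom type, so it presumably has shape (atom_vocab_size, batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat); the removed permute used to move the batch axis first. Returning the type-first layout keeps the predictions aligned with the permutes applied to the gold links in the loss and accuracy below. A shape-only sketch with illustrative sizes:

import torch
import torch.nn.functional as F

atom_vocab_size, batch_size, n = 3, 2, 4        # illustrative sizes
link_weights = [torch.rand(batch_size, n, n) for _ in range(atom_vocab_size)]
total_link_weights = torch.stack(link_weights)            # (vocab, batch, n, n)
scores = F.log_softmax(total_link_weights, dim=3)         # normalise each row
assert scores.shape == (atom_vocab_size, batch_size, n, n)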
......
from .Linker import Linker
from .atom_map import atom_map
from .AtomEmbedding import AtomEmbedding
from .AtomTokenizer import AtomTokenizer
\ No newline at end of file
@@ -9,14 +9,15 @@ class SinkhornLoss(Module):
def forward(self, predictions, truths):
return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
-for link, perm in zip(predictions, truths))
+for link, perm in zip(predictions, truths.permute(1, 0, 2)))
def mesure_accuracy(batch_true_links, axiom_links_pred):
r"""
-batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
-axiom_links_pred : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
+batch_true_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
+axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
"""
+batch_true_links=batch_true_links.permute(1, 0, 2)
correct_links = torch.ones(axiom_links_pred.size())
correct_links[axiom_links_pred != batch_true_links] = 0
correct_links[batch_true_links == -1] = 1
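Note: the loss and the accuracy now agree on that type-first layout. The gold links are stored batch-first, so SinkhornLoss permutes them with .permute(1, 0, 2) before zipping them with the predictions, and mesure_accuracy applies the same permute before the element-wise comparison; entries equal to -1 mark padding and are excluded from the loss (ignore_index=-1) and force-counted as correct here. A tiny illustrative check of that masking:

import torch

axiom_links_pred = torch.tensor([[1, 2, 0]])    # illustrative predicted links
batch_true_links = torch.tensor([[1, 0, -1]])   # gold links, already permuted
correct_links = torch.ones(axiom_links_pred.size())
correct_links[axiom_links_pred != batch_true_links] = 0
correct_links[batch_true_links == -1] = 1       # padding never counts as wrong
print(correct_links)   # tensor([[1., 0., 1.]])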
......