diff --git a/Configuration/config.ini b/Configuration/config.ini
index 69d1a5c600d73b737e8cfd1215cde5379f1cda52..ea8dd6979eeed3bf6cbd6f77724329ce65273c9f 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -12,7 +12,7 @@ max_atoms_in_one_type=510
 dim_encoder = 768
 
 [MODEL_DECODER]
-nhead=8
+nhead=4
 num_layers=1
 dropout=0.1
 dim_feedforward=512
diff --git a/Linker/AtomEmbedding.py b/Linker/AtomEmbedding.py
deleted file mode 100644
index e7be599a0fa145f76a5646b83973a3501ed52d4d..0000000000000000000000000000000000000000
--- a/Linker/AtomEmbedding.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import torch
-from torch.nn import Module, Embedding
-
-
-class AtomEmbedding(Module):
-    def __init__(self, dim_linker, atom_vocab_size, padding_idx=None):
-        super(AtomEmbedding, self).__init__()
-        self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_linker, padding_idx=padding_idx,
-                             scale_grad_by_freq=True)
-
-    def forward(self, x):
-        return self.emb(x)
diff --git a/Linker/Linker.py b/Linker/Linker.py
index dc3e6ee63b469db642ead6bbb8f0520e617652a7..611575d7140c48516b4ce3f2cbd55bf995e04610 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -6,14 +6,13 @@ import datetime
 import time
 
 import torch.nn.functional as F
-from torch.nn import Sequential, LayerNorm, Dropout
+from torch.nn import Sequential, LayerNorm, Dropout, Embedding
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 from Configuration import Configuration
-from Linker.AtomEmbedding import AtomEmbedding
 from Linker.AtomTokenizer import AtomTokenizer
 from Linker.MHA import AttentionDecoderLayer
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
@@ -21,7 +20,6 @@ from Linker.atom_map import atom_map, atom_map_redux
 from Linker.eval import mesure_accuracy, SinkhornLoss
 from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx
 from Supertagger import *
-from utils import pad_sequence
 
 
 def format_time(elapsed):
@@ -62,7 +60,7 @@ class Linker(Module):
         atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
         self.dropout = Dropout(0.1)
-        self.device = "cpu"
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         supertagger = SuperTagger()
         supertagger.load_weights(supertagger_path_model)
@@ -73,7 +71,9 @@ class Linker(Module):
         self.padding_id = self.atom_map['[PAD]']
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         self.inverse_map = self.atoms_tokenizer.inverse_atom_map
-        self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, atom_vocab_size, self.padding_id)
+        self.atoms_embedding = Embedding(num_embeddings=atom_vocab_size, embedding_dim=self.dim_embedding_atoms,
+                                         padding_idx=self.padding_id,
+                                         scale_grad_by_freq=True)
 
         self.linker_encoder = AttentionDecoderLayer()
 
@@ -90,8 +90,6 @@ class Linker(Module):
 
         self.optimizer = AdamW(self.parameters(), lr=learning_rate)
 
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
         self.to(self.device)
 
     def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
@@ -171,9 +169,8 @@ class Linker(Module):
             link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
 
         total_link_weights = torch.stack(link_weights)
-        link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)
 
-        return F.log_softmax(link_weights_per_batch, dim=3)
+        return F.log_softmax(total_link_weights, dim=3)
 
     def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32, checkpoint=True,
                      tensorboard=False):
diff --git a/Linker/__init__.py b/Linker/__init__.py
index b9380b473249eb38f1995474414ceb1eb6ea85ca..92c67b3fcaa9d1121107b979ba57a5bbeba043ea 100644
--- a/Linker/__init__.py
+++ b/Linker/__init__.py
@@ -1,4 +1,3 @@
 from .Linker import Linker
 from .atom_map import atom_map
-from .AtomEmbedding import AtomEmbedding
 from .AtomTokenizer import AtomTokenizer
\ No newline at end of file
diff --git a/Linker/eval.py b/Linker/eval.py
index 1113596e276a190edfc49ac50ce511ad64b4e6c8..e713120ce61d3a43619559bd2eaadf867a958931 100644
--- a/Linker/eval.py
+++ b/Linker/eval.py
@@ -9,14 +9,15 @@ class SinkhornLoss(Module):
 
     def forward(self, predictions, truths):
         return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
-                   for link, perm in zip(predictions, truths))
+                   for link, perm in zip(predictions, truths.permute(1, 0, 2)))
 
 
 def mesure_accuracy(batch_true_links, axiom_links_pred):
     r"""
-    batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
-    axiom_links_pred : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
+    batch_true_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
+    axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
+    batch_true_links = batch_true_links.permute(1, 0, 2)
     correct_links = torch.ones(axiom_links_pred.size())
     correct_links[axiom_links_pred != batch_true_links] = 0
     correct_links[batch_true_links == -1] = 1
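
Note on the Linker.forward and Linker/eval.py hunks: forward now keeps the stacked Sinkhorn scores atom-type-first, shape (atom_vocab_size, batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat), instead of permuting to batch-first, and SinkhornLoss (as well as mesure_accuracy) permutes the batch-first truths to match. The standalone sketch below only illustrates that shape contract; the tensor sizes are made up for the example and are not read from config.ini.

import torch
from torch.nn.functional import log_softmax, nll_loss

# Illustrative sizes (hypothetical, not the real configuration values).
atom_vocab_size, batch_size, n = 18, 4, 10

# What Linker.forward now returns: log-softmaxed Sinkhorn scores stacked
# atom-type-first, i.e. (atom_vocab_size, batch_size, n, n).
link_weights = log_softmax(torch.randn(atom_vocab_size, batch_size, n, n), dim=3)

# Ground-truth links stay batch-first: (batch_size, atom_vocab_size, n),
# with -1 marking padded positions.
truths = torch.randint(-1, n, (batch_size, atom_vocab_size, n))

# SinkhornLoss aligns the two by permuting the truths, as in the diff above.
loss = sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
           for link, perm in zip(link_weights, truths.permute(1, 0, 2)))
print(loss.item())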