From 2889028aaa79046c9d885cbcd334fd99c43d6a5c Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Wed, 25 May 2022 15:44:11 +0200
Subject: [PATCH] Replace AtomEmbedding wrapper with torch.nn.Embedding

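Drop the AtomEmbedding wrapper (it only forwarded to torch.nn.Embedding)
and instantiate the embedding directly in Linker. Pick the device once
from CUDA availability instead of hard-coding "cpu" and overriding it
later, keep the Sinkhorn link weights type-first instead of permuting
them to batch-first, permute the true links inside the loss and the
accuracy instead, and reduce the decoder heads from 8 to 4.

A minimal sketch (not part of the patch, with made-up sizes) of the shape
convention the loss now assumes: predictions are type-first, truths stay
batch-first and are permuted inside SinkhornLoss.

    import torch
    from torch.nn.functional import log_softmax, nll_loss

    atom_types, batch_size, n = 5, 2, 3      # illustrative sizes only

    # linker output, kept type-first: (atom_types, batch_size, n, n)
    link_weights = log_softmax(torch.randn(atom_types, batch_size, n, n), dim=3)

    # true links, batch-first: (batch_size, atom_types, n)
    truths = torch.randint(0, n, (batch_size, atom_types, n))

    # the loss permutes the truths, not the predictions
    loss = sum(nll_loss(link.flatten(0, 1), perm.flatten(), ignore_index=-1)
               for link, perm in zip(link_weights, truths.permute(1, 0, 2)))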
---
 Configuration/config.ini |  2 +-
 Linker/AtomEmbedding.py  | 12 ------------
 Linker/Linker.py         | 15 ++++++---------
 Linker/__init__.py       |  1 -
 Linker/eval.py           |  7 ++++---
 5 files changed, 11 insertions(+), 26 deletions(-)
 delete mode 100644 Linker/AtomEmbedding.py

diff --git a/Configuration/config.ini b/Configuration/config.ini
index 69d1a5c..ea8dd69 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -12,7 +12,7 @@ max_atoms_in_one_type=510
 dim_encoder = 768
 
 [MODEL_DECODER]
-nhead=8
+nhead=4
 num_layers=1
 dropout=0.1
 dim_feedforward=512
diff --git a/Linker/AtomEmbedding.py b/Linker/AtomEmbedding.py
deleted file mode 100644
index e7be599..0000000
--- a/Linker/AtomEmbedding.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import torch
-from torch.nn import Module, Embedding
-
-
-class AtomEmbedding(Module):
-    def __init__(self, dim_linker, atom_vocab_size, padding_idx=None):
-        super(AtomEmbedding, self).__init__()
-        self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_linker, padding_idx=padding_idx,
-                             scale_grad_by_freq=True)
-
-    def forward(self, x):
-        return self.emb(x)
diff --git a/Linker/Linker.py b/Linker/Linker.py
index dc3e6ee..611575d 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -6,14 +6,13 @@ import datetime
 import time
 
 import torch.nn.functional as F
-from torch.nn import Sequential, LayerNorm, Dropout
+from torch.nn import Sequential, LayerNorm, Dropout, Embedding
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 from Configuration import Configuration
-from Linker.AtomEmbedding import AtomEmbedding
 from Linker.AtomTokenizer import AtomTokenizer
 from Linker.MHA import AttentionDecoderLayer
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
@@ -21,7 +20,6 @@ from Linker.atom_map import atom_map, atom_map_redux
 from Linker.eval import mesure_accuracy, SinkhornLoss
 from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx
 from Supertagger import *
-from utils import pad_sequence
 
 
 def format_time(elapsed):
@@ -62,7 +60,7 @@ class Linker(Module):
         atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
         self.dropout = Dropout(0.1)
-        self.device = "cpu"
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         supertagger = SuperTagger()
         supertagger.load_weights(supertagger_path_model)
@@ -73,7 +71,9 @@ class Linker(Module):
         self.padding_id = self.atom_map['[PAD]']
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         self.inverse_map = self.atoms_tokenizer.inverse_atom_map
-        self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, atom_vocab_size, self.padding_id)
+        self.atoms_embedding = Embedding(num_embeddings=atom_vocab_size, embedding_dim=self.dim_embedding_atoms,
+                                         padding_idx=self.padding_id,
+                                         scale_grad_by_freq=True)
 
         self.linker_encoder = AttentionDecoderLayer()
 
@@ -90,8 +90,6 @@ class Linker(Module):
         self.optimizer = AdamW(self.parameters(),
                                lr=learning_rate)
 
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
         self.to(self.device)
 
     def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
@@ -171,9 +169,8 @@ class Linker(Module):
             link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
 
         total_link_weights = torch.stack(link_weights)
-        link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)
 
-        return F.log_softmax(link_weights_per_batch, dim=3)
+        return F.log_softmax(total_link_weights, dim=3)
 
     def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
                      batch_size=32, checkpoint=True, tensorboard=False):
diff --git a/Linker/__init__.py b/Linker/__init__.py
index b9380b4..92c67b3 100644
--- a/Linker/__init__.py
+++ b/Linker/__init__.py
@@ -1,4 +1,3 @@
 from .Linker import Linker
 from .atom_map import atom_map
-from .AtomEmbedding import AtomEmbedding
 from .AtomTokenizer import AtomTokenizer
\ No newline at end of file
diff --git a/Linker/eval.py b/Linker/eval.py
index 1113596..e713120 100644
--- a/Linker/eval.py
+++ b/Linker/eval.py
@@ -9,14 +9,15 @@ class SinkhornLoss(Module):
 
     def forward(self, predictions, truths):
         return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
-                   for link, perm in zip(predictions, truths))
+                   for link, perm in zip(predictions, truths.permute(1, 0, 2)))
 
 
 def mesure_accuracy(batch_true_links, axiom_links_pred):
     r"""
-    batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
-    axiom_links_pred : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
+    batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the gold indices of the negative atoms
+    axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the predicted indices of the negative atoms
     """
+    batch_true_links = batch_true_links.permute(1, 0, 2)  # align the true links with the layout of axiom_links_pred
     correct_links = torch.ones(axiom_links_pred.size())
     correct_links[axiom_links_pred != batch_true_links] = 0
     correct_links[batch_true_links == -1] = 1
-- 
GitLab