From 0eb359e3ff924afe1e2fa5d52f4973e1d28db82a Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Wed, 4 May 2022 11:44:34 +0200
Subject: [PATCH] progress on linker

---
 Configuration/Configuration.py         |   3 +
 Configuration/config.ini               |  22 ++++-
 SuperTagger/Decoder/RNNDecoderLayer.py |   6 +-
 SuperTagger/EncoderDecoder.py          |   4 +-
 SuperTagger/Linker/AttentionLayer.py   | 109 +++++++++++++++++++++++++
 SuperTagger/Linker/Linker.py           |  52 ++++++------
 6 files changed, 162 insertions(+), 34 deletions(-)
 create mode 100644 SuperTagger/Linker/AttentionLayer.py

diff --git a/Configuration/Configuration.py b/Configuration/Configuration.py
index 3d94c9b..9be7829 100644
--- a/Configuration/Configuration.py
+++ b/Configuration/Configuration.py
@@ -11,7 +11,10 @@ config.read(path_config_file)
 
 version = config["VERSION"]
 
+datasetConfig = config["DATASET_PARAMS"]
+modelEncoderConfig = config["MODEL_ENCODER"]
 modelDecoderConfig = config["MODEL_DECODER"]
+modelLinkerConfig = config["MODEL_LINKER"]
 modelTrainingConfig = config["MODEL_TRAINING"]
 
 # endregion Get section
diff --git a/Configuration/config.ini b/Configuration/config.ini
index 7e8b403..1d2d4b0 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -1,13 +1,31 @@
 [VERSION]
 transformers = 4.16.2
+
+[DATASET_PARAMS]
+symbols_vocab_size=26
+atom_vocab_size=12
+max_len_sentence=148
+max_symbols_in_sentence=1250
+
+[MODEL_ENCODER]
+dim_encoder = 768
+
 [MODEL_DECODER]
 dim_encoder = 768
 dim_decoder = 8
 num_rnn_layers=1
 dropout=0.1
 teacher_forcing=0.05
-symbols_vocab_size=26
-max_len_sentence=148
+
+[MODEL_LINKER]
+nhead=8
+dim_feedforward=246
+dim_embedding_atoms=8
+dim_polarity_transfo=128
+layer_norm_eps=1e-5
+dropout=0.1
+sinkhorn_iters=3
+
 [MODEL_TRAINING]
 device=cpu
 batch_size=32
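
Note (not part of the patch): configparser returns every value as a string, so callers cast
with int()/float(), as the rest of this patch does. A minimal sketch of reading the new
sections through the existing Configuration module:

    from Configuration import Configuration

    nhead = int(Configuration.modelLinkerConfig['nhead'])                      # [MODEL_LINKER]
    dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])         # [MODEL_ENCODER]
    max_symbols = int(Configuration.datasetConfig['max_symbols_in_sentence'])  # [DATASET_PARAMS]
    sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
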
diff --git a/SuperTagger/Decoder/RNNDecoderLayer.py b/SuperTagger/Decoder/RNNDecoderLayer.py
index 9c6c12b..93e96a6 100644
--- a/SuperTagger/Decoder/RNNDecoderLayer.py
+++ b/SuperTagger/Decoder/RNNDecoderLayer.py
@@ -13,13 +13,13 @@ class RNNDecoderLayer(Module):
         super(RNNDecoderLayer, self).__init__()
 
         # init params
-        self.dim_encoder = int(Configuration.modelDecoderConfig['dim_encoder'])
+        self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
         self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
-        self.max_len_sentence = int(Configuration.modelDecoderConfig['max_len_sentence'])
-        self.symbols_vocab_size = int(Configuration.modelDecoderConfig['symbols_vocab_size'])
         dropout = float(Configuration.modelDecoderConfig['dropout'])
         self.num_rnn_layers = int(Configuration.modelDecoderConfig['num_rnn_layers'])
         self.teacher_forcing = float(Configuration.modelDecoderConfig['teacher_forcing'])
+        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
+        self.symbols_vocab_size = int(Configuration.datasetConfig['symbols_vocab_size'])
 
         self.bidirectional = False
         self.use_attention = True
diff --git a/SuperTagger/EncoderDecoder.py b/SuperTagger/EncoderDecoder.py
index 0519bc2..36311d5 100644
--- a/SuperTagger/EncoderDecoder.py
+++ b/SuperTagger/EncoderDecoder.py
@@ -19,8 +19,8 @@ class EncoderDecoder(Module):
     def __init__(self, BASE_TOKENIZER, BASE_MODEL, symbols_map):
         super(EncoderDecoder, self).__init__()
 
-        self.max_len_sentence = int(Configuration.modelDecoderConfig['max_len_sentence'])
-        self.max_symbols_in_sentence = int(Configuration.modelDecoderConfig['max_symbols_in_sentence'])
+        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
+        self.max_symbols_in_sentence = int(Configuration.datasetConfig['max_symbols_in_sentence'])
         self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
 
         self.symbols_map = symbols_map
diff --git a/SuperTagger/Linker/AttentionLayer.py b/SuperTagger/Linker/AttentionLayer.py
new file mode 100644
index 0000000..150df88
--- /dev/null
+++ b/SuperTagger/Linker/AttentionLayer.py
@@ -0,0 +1,109 @@
+from torch import Tensor
+import torch
+from torch.nn import (GELU, Dropout, LayerNorm, Linear, Module, MultiheadAttention,
+                      Sequential)
+
+from Configuration import Configuration
+from SuperTagger.Symbol.SymbolEmbedding import SymbolEmbedding
+
+
+class FFN(Module):
+    "Implements FFN equation."
+
+    def __init__(self, d_model, d_ff, dropout=0.1):
+        super(FFN, self).__init__()
+        self.ffn = Sequential(
+            Linear(d_model, d_ff, bias=False),
+            GELU(),
+            Dropout(dropout),
+            Linear(d_ff, d_model, bias=False)
+        )
+
+    def forward(self, x):
+        return self.ffn(x)
+
+
+class AttentionLayer(Module):
+    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
+    This standard decoder layer is based on the paper "Attention Is All You Need".
+    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+    in a different way during application.
+
+    Args:
+        dim_model: the number of expected features in the input (required).
+        nhead: the number of heads in the multiheadattention models (required).
+        dim_feedforward: the dimension of the feedforward network model (default=2048).
+        dropout: the dropout value (default=0.1).
+        activation: the activation function of the intermediate layer, can be a string
+            ("relu" or "gelu") or a unary callable. Default: relu
+        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
+        batch_first: If ``True``, then the input and output tensors are provided
+            as (batch, seq, feature). Default: ``False``.
+        norm_first: if ``True``, layer norm is done prior to self attention, multihead
+            attention and feedforward operations, respectivaly. Otherwise it's done after.
+            Default: ``False`` (after).
+    """
+    __constants__ = ['batch_first', 'norm_first']
+
+    def __init__(self) -> None:
+        super(AttentionLayer, self).__init__()
+
+        # init params
+        dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
+        dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
+        dim_feedforward = int(Configuration.modelLinkerConfig['dim_feedforward'])
+        dropout = float(Configuration.modelLinkerConfig['dropout'])
+        layer_norm_eps = float(Configuration.modelLinkerConfig['layer_norm_eps'])
+        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
+        self.max_symbols_in_sentence = int(Configuration.datasetConfig['max_symbols_in_sentence'])
+
+        self.symbols_embedder = SymbolEmbedding(dim_embedding_atoms, int(Configuration.datasetConfig['symbols_vocab_size']))
+
+        # layers
+        self.dropout = Dropout(dropout)
+        self.self_attn = MultiheadAttention(dim_embedding_atoms, self.nhead, dropout=dropout, batch_first=True,
+                                            kdim=dim_embedding_atoms, vdim=dim_embedding_atoms)
+        self.norm1 = LayerNorm(dim_embedding_atoms, eps=layer_norm_eps)
+        self.multihead_attn = MultiheadAttention(dim_embedding_atoms, self.nhead, dropout=dropout,
+                                                 kdim=dim_encoder, vdim=dim_encoder,
+                                                 batch_first=True)
+        self.norm2 = LayerNorm(dim_embedding_atoms, eps=layer_norm_eps)
+        self.ffn = FFN(d_model=dim_embedding_atoms, d_ff=dim_feedforward, dropout=dropout)
+        self.norm3 = LayerNorm(dim_embedding_atoms, eps=layer_norm_eps)
+
+    def forward(self, atoms_embeddings, sents_embedding, encoder_mask, decoder_mask):
+        r"""Pass the inputs through the decoder layer.
+
+        Args:
+            atoms_embeddings: the embedded atom sequence fed to the decoder layer (required).
+            sents_embedding: the sequence from the last layer of the encoder (required).
+        """
+        x = atoms_embeddings
+        x = self.norm1(x + self._mask_mha_block(x, decoder_mask))
+        x = self.norm2(x + self._mha_block(x, sents_embedding, encoder_mask))
+        x = self.norm3(x + self._ff_block(x))
+
+        return x
+
+    # self-attention block
+    def _mask_mha_block(self, x: Tensor, decoder_mask: Tensor) -> Tensor:
+        if decoder_mask is not None:
+            # Same mask applied to all h heads.
+            decoder_mask = decoder_mask.repeat(self.nhead, 1, 1)
+        x = self.self_attn(x, x, x, attn_mask=decoder_mask)[0]
+        return x
+
+    # multihead attention block
+    def _mha_block(self, x: Tensor, sents_embs: Tensor, encoder_mask: Tensor) -> Tensor:
+        if encoder_mask is not None:
+            # Same mask applied to all h heads.
+            encoder_mask = encoder_mask.repeat(self.nhead, 1, 1)
+        x = self.multihead_attn(x, sents_embs, sents_embs, attn_mask=encoder_mask)[0]
+        return x
+
+    # feed forward block
+    def _ff_block(self, x: Tensor) -> Tensor:
+        x = self.ffn(x)
+        return x
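
Note (not part of the patch): a minimal usage sketch of the new layer, with hypothetical batch
sizes; the feature dimensions follow the config.ini defaults above (dim_embedding_atoms=8,
dim_encoder=768), and it assumes config.ini and SuperTagger.Symbol.SymbolEmbedding are
importable as in the repository:

    import torch
    from SuperTagger.Linker.AttentionLayer import AttentionLayer

    layer = AttentionLayer()                       # reads its hyperparameters from config.ini
    batch_size, n_atoms, len_sentence = 2, 10, 12  # hypothetical sizes
    atoms_embeddings = torch.rand(batch_size, n_atoms, 8)        # dim_embedding_atoms
    sents_embedding = torch.rand(batch_size, len_sentence, 768)  # dim_encoder
    out = layer(atoms_embeddings, sents_embedding, None, None)   # masks are optional
    print(out.shape)                               # torch.Size([2, 10, 8])
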
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
index 745b7d9..6b5c6f1 100644
--- a/SuperTagger/Linker/Linker.py
+++ b/SuperTagger/Linker/Linker.py
@@ -9,50 +9,50 @@ from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
 from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
 from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
+from SuperTagger.Linker.AttentionLayer import FFN, AttentionLayer
 
 
-def FFN(d_model, d_ff, dropout_rate=0.1, d_out=None):
-    return Sequential(
-        Linear(d_model, d_ff, bias=False),
-        GELU(),
-        Dropout(dropout_rate),
-        Linear(d_ff, d_model if d_out is None else d_out, bias=False)
-    )
-
 
 class Linker:
     def __init__(self):
         self.__init__()
 
-        self.dim_encoder = int(Configuration.modelDecoderConfig['dim_encoder'])
-        self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
-        self.dim_linker = int(Configuration.modelDecoderConfig['dim_linker'])
-        self.max_atoms_in_sentence = int(Configuration.modelDecoderConfig['max_atoms_in_sentence'])
-        self.atom_vocab_size = int(Configuration.modelDecoderConfig['atom_vocab_size'])
-
+        self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
+        self.dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo'])
+        self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
+        self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
+        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
+        self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
         self.dropout = Dropout(0.1)
 
         self.atom_map = atom_map
         self.padding_id = self.atom_map['[PAD]']
         self.atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
-        self.atom_embedding = AtomEmbedding(self.dim_linker, self.atom_vocab_size, self.padding_id)
+        self.atom_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
 
         # to do : define an encoding
-        self.linker_encoder = FFN(self.dim_linker, self.dim_linker, 0.1)
+        self.linker_encoder = AttentionLayer()
 
         self.pos_transformation = Sequential(
-            FFN(self.dim_decoder, self.dim_decoder, 0.1),
-            LayerNorm(self.dim_decoder, eps=1e-12)
+            FFN(self.dim_polarity_transfo, self.dim_polarity_transfo, 0.1),
+            LayerNorm(self.dim_polarity_transfo, eps=1e-12)
         )
         self.neg_transformation = Sequential(
-            FFN(self.dim_decoder, self.dim_decoder, 0.1),
-            LayerNorm(self.dim_decoder, eps=1e-12)
+            FFN(self.dim_polarity_transfo, self.dim_polarity_transfo, 0.1),
+            LayerNorm(self.dim_polarity_transfo, eps=1e-12)
         )
 
-    def forward(self, category_batch):
+    def make_decoder_mask(self, atoms_batch):
+        # boolean attention mask: True marks padding positions that attention must ignore
+        decoder_attn_mask = atoms_batch.eq(self.padding_id)
+        return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_batch.shape[1], 1)
+
+    def forward(self, category_batch, sents_embedding, sents_mask):
         '''
         Parameters :
-        symbols_decoding : batch of size (batch_size, sequence_length) = output of decoder
+        category_batch : batch of size (batch_size, sequence_length) = output of decoder
+        sents_embedding : encoder (BERT) output of size (batch_size, len_sentence, dim_encoder)
+        sents_mask : attention mask over the encoder output
         Returns :
         link_weights : (batch_size, atom_vocab_size, ...)
         '''
@@ -63,10 +63,8 @@ class Linker:
         atoms_embedding = self.atom_embedding(atoms_batch)
 
         # MHA or LSTM over the BERT output
-        #
-        # TO DO
-        # atoms_encoding = self.linker_encoder(atoms_embedding)
-        #
+        # decoder_mask = self.make_decoder_mask(atoms_batch)
+        # atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask, decoder_mask)
         atoms_encoding = atoms_embedding
 
         # find atoms polarity : list (not tensor) (batch_size, max_atoms_in sentence)
@@ -87,6 +85,6 @@ class Linker:
                 neg_encoding = self.neg_transformation(neg_encoding)
 
                 weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
-                link_weights.append(sinkhorn(weights, iters=3))
+                link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
 
         return link_weights
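
Note (not part of the patch): for intuition only, a self-contained toy version of log-space
Sinkhorn normalization; this is not the repository's sinkhorn_fn_no_exp, only an illustration
of how iterating row/column normalization turns the pos/neg similarity matrix into soft
permutation (link) weights:

    import torch

    def log_sinkhorn(scores: torch.Tensor, iters: int = 3) -> torch.Tensor:
        # alternately normalize rows and columns in log space
        log_p = scores
        for _ in range(iters):
            log_p = log_p - torch.logsumexp(log_p, dim=-1, keepdim=True)  # rows
            log_p = log_p - torch.logsumexp(log_p, dim=-2, keepdim=True)  # columns
        return log_p

    # hypothetical similarity matrix between 4 positive and 4 negative atoms of one type
    weights = torch.rand(1, 4, 4)
    link_weights = log_sinkhorn(weights, iters=3)
    print(link_weights.exp().sum(dim=-2))  # columns sum to ~1 after the last normalization
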
-- 
GitLab