Commit 85a4adab authored by Caroline DE POURTALES

adding mha

parent 8b0f5bb5
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
[MODEL_LINKER] configuration:

@@ -18,7 +18,7 @@
 dropout=0.1
 teacher_forcing=0.05
 [MODEL_LINKER]
-nhead=8
+nhead=1
 dim_feedforward=246
 dim_embedding_atoms=8
 dim_polarity_transfo=128
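These [MODEL_LINKER] values are consumed through the project's Configuration wrapper, as the Linker changes below show. A minimal sketch of that usage, assuming Configuration.modelLinkerConfig exposes the parsed [MODEL_LINKER] section as a mapping of strings (the loading code itself is not part of this commit):

from Configuration import Configuration

# Hedged sketch: reading the hyperparameters touched above.
# Assumes Configuration.modelLinkerConfig maps [MODEL_LINKER] keys to string values.
nhead = int(Configuration.modelLinkerConfig['nhead'])                                 # 1 after this commit
dim_feedforward = int(Configuration.modelLinkerConfig['dim_feedforward'])             # 246
dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])     # 8
dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo'])   # 128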
Linker module:

 from itertools import chain
 import torch
-from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU
+from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention
 from torch.nn import Module
 import torch.nn.functional as F

 from Configuration import Configuration
 from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
 from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
+from SuperTagger.Linker.MHA import AttentionDecoderLayer
 from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
+from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch, FFN
 from SuperTagger.eval import mesure_accuracy
 from SuperTagger.utils import pad_sequence


-class FFN(Module):
-    "Implements FFN equation."
-
-    def __init__(self, d_model, d_ff, dropout=0.1):
-        super(FFN, self).__init__()
-        self.ffn = Sequential(
-            Linear(d_model, d_ff, bias=False),
-            GELU(),
-            Dropout(dropout),
-            Linear(d_ff, d_model, bias=False)
-        )
-
-    def forward(self, x):
-        return self.ffn(x)
-
-
 class Linker(Module):
     def __init__(self):
         super(Linker, self).__init__()
@@ -39,6 +24,8 @@ class Linker(Module):
         self.dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo'])
         self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
         self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
+        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
+        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
         self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
         self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
@@ -50,7 +37,7 @@ class Linker(Module):
         self.atom_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)

         # to do: define an encoding
-        # self.linker_encoder =
+        self.linker_encoder = AttentionDecoderLayer()

         self.pos_transformation = Sequential(
             FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
@@ -61,23 +48,32 @@ class Linker(Module):
             LayerNorm(self.dim_embedding_atoms, eps=1e-12)
         )

-    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding):
+    def make_decoder_mask(self, atoms_token):
+        decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
+        decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
+        return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
+
+    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
         r'''
         Parameters :
         atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
         atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
-        sents_embedding : output of BERT for context
+        sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+        sents_mask
         Returns :
         link_weights : (atom_vocab_size, batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat)
         '''
         # atoms embedding
         atoms_embedding = self.atom_embedding(atoms_batch_tokenized)
+        print(atoms_embedding.shape)

         # MHA or LSTM over the BERT output
-        # decoder_mask = self.make_decoder_mask(atoms_batch)
-        # atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask, decoder_mask)
-        atoms_encoding = atoms_embedding
+        sents_embedding = torch.randn(32, self.max_len_sentence, self.dim_encoder)
+        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
+        sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence)
+        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask, self.make_decoder_mask(atoms_batch_tokenized))
+        # atoms_encoding = atoms_embedding

         link_weights = []
         for atom_type in list(self.atom_map.keys())[:-1]:
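The new make_decoder_mask builds a per-head 3-D attention mask from the padded atom token ids. A standalone shape check of that construction (batch size, padding id and token values below are illustrative, not taken from the repository):

import torch

nhead = 1        # matches the [MODEL_LINKER] value set in this commit
padding_id = 0   # assumed padding token id for this sketch
atoms_token = torch.tensor([[4, 7, 2, 0, 0]])  # (batch_size=1, max_atoms_in_sentence=5), toy ids

decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
decoder_attn_mask[atoms_token.eq(padding_id)] = 0.0
decoder_attn_mask = decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(nhead, 1, 1)

print(decoder_attn_mask.shape)  # torch.Size([1, 5, 5]) -> (batch_size * nhead, max_atoms, max_atoms)

This (batch_size * nhead, target_len, source_len) layout is the 3-D attn_mask form accepted by torch.nn.MultiheadAttention, which is how AttentionDecoderLayer below applies both the decoder and encoder masks.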
SuperTagger/Linker/MHA.py (new file):

import copy

import torch
import torch.nn.functional as F
import torch.optim as optim
from Configuration import Configuration
from torch import Tensor, LongTensor
from torch.nn import (GELU, LSTM, Dropout, LayerNorm, Linear, Module, MultiheadAttention,
                      ModuleList, Sequential)

from SuperTagger.Linker.utils import FFN


class AttentionDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.

    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        dim_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.
        norm_first: if ``True``, layer norm is done prior to self attention, multihead
            attention and feedforward operations, respectively. Otherwise it's done after.
            Default: ``False`` (after).
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self) -> None:
        super(AttentionDecoderLayer, self).__init__()

        # init params
        dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
        dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
        max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
        atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
        nhead = int(Configuration.modelLinkerConfig['nhead'])
        dropout = float(Configuration.modelLinkerConfig['dropout'])
        dim_feedforward = int(Configuration.modelLinkerConfig['dim_feedforward'])
        layer_norm_eps = float(Configuration.modelLinkerConfig['layer_norm_eps'])

        # layers
        self.dropout = Dropout(dropout)
        self.self_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout, batch_first=True,
                                            kdim=dim_decoder, vdim=dim_decoder)
        self.norm1 = LayerNorm(dim_decoder, eps=layer_norm_eps)
        self.multihead_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout,
                                                 kdim=dim_encoder, vdim=dim_encoder,
                                                 batch_first=True)
        self.norm2 = LayerNorm(dim_decoder, eps=layer_norm_eps)
        self.ffn = FFN(d_model=dim_decoder, d_ff=dim_feedforward, dropout=dropout)
        self.norm3 = LayerNorm(dim_decoder, eps=layer_norm_eps)

    def forward(self, atoms_embedding: Tensor, sents_embedding: Tensor, encoder_mask: Tensor,
                decoder_mask: Tensor) -> Tensor:
        r"""Pass the inputs through the decoder layer.

        Args:
            atoms_embedding: the sequence to the decoder layer (required).
            sents_embedding: the sequence from the last layer of the encoder (required).
            encoder_mask
            decoder_mask
        """
        x = atoms_embedding
        x = self.norm1(x + self._mask_mha_block(atoms_embedding, decoder_mask))
        x = self.norm2(x + self._mha_block(x, sents_embedding, encoder_mask))
        x = self.norm3(x + self._ff_block(x))
        return x

    # self-attention block
    def _mask_mha_block(self, x: Tensor, decoder_mask: Tensor) -> Tensor:
        x = self.self_attn(x, x, x, attn_mask=decoder_mask)[0]
        return x

    # multihead attention block
    def _mha_block(self, x: Tensor, sents_embs: Tensor, encoder_mask: Tensor) -> Tensor:
        x = self.multihead_attn(x, sents_embs, sents_embs, attn_mask=encoder_mask)[0]
        return x

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.ffn.forward(x)
        return x
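AttentionDecoderLayer reads all of its sizes from Configuration, so a fully standalone call is easiest to illustrate with the underlying torch.nn.MultiheadAttention blocks. Below is a minimal sketch of the same post-norm residual flow (self-attention over the atom embeddings, then cross-attention onto the sentence embedding); every dimension is an illustrative assumption, not a value fixed by this commit, and the masks and FFN block are left out for brevity:

import torch
from torch.nn import LayerNorm, MultiheadAttention

# Illustrative sizes (assumptions for this sketch only)
batch_size, max_atoms, len_sentence = 2, 6, 10
dim_decoder, dim_encoder, nhead = 8, 768, 1

self_attn = MultiheadAttention(dim_decoder, nhead, batch_first=True)
cross_attn = MultiheadAttention(dim_decoder, nhead, kdim=dim_encoder, vdim=dim_encoder, batch_first=True)
norm1, norm2 = LayerNorm(dim_decoder), LayerNorm(dim_decoder)

atoms_embedding = torch.randn(batch_size, max_atoms, dim_decoder)     # decoder-side input (atom embeddings)
sents_embedding = torch.randn(batch_size, len_sentence, dim_encoder)  # encoder-side input (BERT output)

# Post-norm residual blocks, mirroring _mask_mha_block and _mha_block above
x = norm1(atoms_embedding + self_attn(atoms_embedding, atoms_embedding, atoms_embedding)[0])
x = norm2(x + cross_attn(x, sents_embedding, sents_embedding)[0])
print(x.shape)  # torch.Size([2, 6, 8]) -- the atoms re-encoded in the decoder dimension

In the class above, the FFN plus norm3 block is then applied to x in the same residual fashion.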
SuperTagger/Linker/utils.py:

@@ -2,11 +2,29 @@
 import re
 import regex
 import numpy as np
 import torch
+from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention
+from torch.nn import Module

 from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
 from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.utils import pad_sequence


+class FFN(Module):
+    "Implements FFN equation."
+
+    def __init__(self, d_model, d_ff, dropout=0.1):
+        super(FFN, self).__init__()
+        self.ffn = Sequential(
+            Linear(d_model, d_ff, bias=False),
+            GELU(),
+            Dropout(dropout),
+            Linear(d_ff, d_model, bias=False)
+        )
+
+    def forward(self, x):
+        return self.ffn(x)
+
+
 regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'

@@ -29,9 +47,10 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
     for atom_type in list(atom_map.keys())[:-1]:
         # filter atom_batch to this type, then filter atom polarity with the indices
         l_polarity_plus = [[x for i, x in enumerate(atoms_batch[s_idx]) if atoms_polarity[s_idx, i]
-                            and bool(re.search(atom_type+"_", atoms_batch[s_idx][i]))] for s_idx in range(len(atoms_batch))]
+                            and bool(re.search(atom_type + "_", atoms_batch[s_idx][i]))] for s_idx in
+                           range(len(atoms_batch))]
         l_polarity_minus = [[x for i, x in enumerate(atoms_batch[s_idx]) if not atoms_polarity[s_idx, i]
-                             and bool(re.search(atom_type+"_", atoms_batch[s_idx][i]))] for s_idx in
+                             and bool(re.search(atom_type + "_", atoms_batch[s_idx][i]))] for s_idx in
                              range(len(atoms_batch))]

         linking_plus_to_minus = pad_sequence(
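To make the positive/negative filtering in get_axiom_links concrete, here is a toy walk-through. The atom names and polarities are invented for the example, and the type_index naming (e.g. "n_1") is an assumption read off the atom_type + "_" search pattern:

import re
import torch

# One toy sentence: flattened atoms and their polarities (True = positive occurrence)
atoms_batch = [["np_0", "n_1", "n_2", "s_3"]]
atoms_polarity = torch.tensor([[True, True, False, False]])
atom_type = "n"

l_polarity_plus = [[x for i, x in enumerate(atoms_batch[s_idx]) if atoms_polarity[s_idx, i]
                    and bool(re.search(atom_type + "_", atoms_batch[s_idx][i]))] for s_idx in
                   range(len(atoms_batch))]
l_polarity_minus = [[x for i, x in enumerate(atoms_batch[s_idx]) if not atoms_polarity[s_idx, i]
                     and bool(re.search(atom_type + "_", atoms_batch[s_idx][i]))] for s_idx in
                    range(len(atoms_batch))]

print(l_polarity_plus)   # [['n_1']] -- positive atoms of type "n"
print(l_polarity_minus)  # [['n_2']] -- negative atoms of type "n"

Note that searching for "n_" rather than "n" keeps the "np_0" atom out of the "n" buckets.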