Commit 78ec3dc4 authored by Caroline DE POURTALES

starting train

parent a702fd51
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
......@@ -6,6 +6,7 @@ symbols_vocab_size=26
atom_vocab_size=12
max_len_sentence=148
max_symbols_in_sentence=1250
max_atoms_in_one_type=50
[MODEL_ENCODER]
dim_encoder = 768
......
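The new max_atoms_in_one_type=50 entry is read back in the Linker as Configuration.datasetConfig['max_atoms_in_one_type'] (see the Linker hunk below). Configuration presumably wraps Python's configparser; a minimal sketch of that assumption, with a hypothetical config path and section name:

import configparser

config = configparser.ConfigParser()
config.read("Configuration/config.ini")    # hypothetical path
datasetConfig = config["DATASET"]          # hypothetical section name
max_atoms_in_one_type = int(datasetConfig["max_atoms_in_one_type"])  # -> 50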
from torch import Tensor
from torch.nn import (GELU, Dropout, LayerNorm, Linear, Module, MultiheadAttention,
Sequential)
from Configuration import Configuration
from SuperTagger.Symbol.SymbolEmbedding import SymbolEmbedding
class FullyConnectedFeedForward(Module):
"Implements FFN equation."
def __init__(self, d_model, d_ff, dropout=0.1):
super(FullyConnectedFeedForward, self).__init__()
self.ffn = Sequential(
Linear(d_model, d_ff, bias=False),
GELU(),
Dropout(dropout),
Linear(d_ff, d_model, bias=False)
)
def forward(self, x):
return self.ffn(x)
class AttentionDecoderLayer(Module):
r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
This standard decoder layer is based on the paper "Attention Is All You Need".
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.
Args:
dim_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of the intermediate layer, can be a string
("relu" or "gelu") or a unary callable. Default: relu
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False``.
norm_first: if ``True``, layer norm is done prior to self attention, multihead
attention and feedforward operations, respectivaly. Otherwise it's done after.
Default: ``False`` (after).
"""
__constants__ = ['batch_first', 'norm_first']
def __init__(self) -> None:
factory_kwargs = {'device': Configuration.modelDecoderConfig['device'],
'dtype': Configuration.modelDecoderConfig['dtype']}
super(AttentionDecoderLayer, self).__init__()
# init params
dim_encoder = int(Configuration.modelDecoderConfig['dim_encoder'])
dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
self.max_symbols_in_sentence = int(Configuration.modelDecoderConfig['max_symbols_in_sentence'])
nhead = int(Configuration.modelDecoderConfig['nhead'])
dropout = float(Configuration.modelDecoderConfig['dropout'])
dim_feedforward = int(Configuration.modelDecoderConfig['dim_feedforward'])
layer_norm_eps = float(Configuration.modelDecoderConfig['layer_norm_eps'])
        self.nhead = nhead
        self.symbols_vocab_size = int(Configuration.datasetConfig['symbols_vocab_size'])
        # embed symbol token ids into the decoder dimension
        self.symbols_embedder = SymbolEmbedding(dim_decoder, self.symbols_vocab_size)
# layers
self.dropout = Dropout(dropout)
self.self_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout, batch_first=True,
kdim=dim_decoder, vdim=dim_decoder)
self.norm1 = LayerNorm(dim_decoder, eps=layer_norm_eps)
self.multihead_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout,
kdim=dim_encoder, vdim=dim_encoder,
batch_first=True)
self.norm2 = LayerNorm(dim_decoder, eps=layer_norm_eps)
self.ffn = FullyConnectedFeedForward(d_model=dim_decoder, d_ff=dim_feedforward, dropout=dropout)
self.norm3 = LayerNorm(dim_decoder, eps=layer_norm_eps)
def forward(self, symbols_tokens: Tensor, sents_embedding: Tensor, encoder_mask: Tensor,
decoder_mask: Tensor) -> Tensor:
r"""Pass the inputs through the decoder layer.
Args:
symbols: the sequence to the decoder layer (required).
sents: the sequence from the last layer of the encoder (required).
"""
x = symbols_tokens
x = self.symbols_embedder(symbols_tokens)
x = self.norm1(x + self._mask_mha_block(x, decoder_mask))
x = self.norm2(x + self._mha_block(x, sents_embedding, encoder_mask))
x = self.norm3(x + self._ff_block(x))
return x
# self-attention block
def _mask_mha_block(self, x: Tensor, decoder_mask: Tensor) -> Tensor:
if decoder_mask is not None:
# Same mask applied to all h heads.
decoder_mask = decoder_mask.repeat(self.nhead, 1, 1)
x = self.self_attn(x, x, x, attn_mask=decoder_mask)[0]
return x
# multihead attention block
def _mha_block(self, x: Tensor, sents_embs: Tensor, encoder_mask: Tensor) -> Tensor:
if encoder_mask is not None:
# Same mask applied to all h heads.
encoder_mask = encoder_mask.repeat(self.nhead, 1, 1)
x = self.multihead_attn(x, sents_embs, sents_embs, attn_mask=encoder_mask)[0]
return x
# feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        return self.ffn(x)
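For reference, a minimal, self-contained sketch of the same post-norm residual pattern used in AttentionDecoderLayer.forward (dimensions here are arbitrary and not taken from Configuration):

import torch
from torch.nn import LayerNorm, MultiheadAttention

dim_decoder, dim_encoder, nhead = 32, 48, 4
self_attn = MultiheadAttention(dim_decoder, nhead, batch_first=True)
cross_attn = MultiheadAttention(dim_decoder, nhead, kdim=dim_encoder, vdim=dim_encoder, batch_first=True)
norm1, norm2 = LayerNorm(dim_decoder), LayerNorm(dim_decoder)

x = torch.randn(2, 10, dim_decoder)       # decoder-side embeddings (batch, seq, feature)
memory = torch.randn(2, 15, dim_encoder)  # encoder output (batch, src_seq, feature)
x = norm1(x + self_attn(x, x, x)[0])             # residual + self-attention (mask omitted here)
x = norm2(x + cross_attn(x, memory, memory)[0])  # residual + cross-attention over the encoder output
print(x.shape)  # torch.Size([2, 10, 32])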
......@@ -41,6 +41,7 @@ class Linker(Module):
self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
self.dropout = Dropout(0.1)
......@@ -100,11 +101,11 @@ class Linker(Module):
# TODO: select with a list of lists
pos_encoding = pad_sequence(
[atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_sentence,
for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_one_type//2,
padding_value=0)
neg_encoding = pad_sequence(
[atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_sentence,
for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_one_type//2,
padding_value=0)
# pos_encoding = self.pos_transformation(pos_encoding)
......
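Note that pad_sequence here is called with a max_len argument, so it is presumably the repository's own padding helper rather than torch.nn.utils.rnn.pad_sequence (which has no such parameter). A rough sketch of what such a helper might look like, stated as an assumption only:

import torch

def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
    # Hypothetical re-implementation: pad or truncate each sequence to max_len
    # along dim 0, then stack everything into one (batch, max_len, ...) tensor.
    padded = []
    for seq in sequences:
        if seq.size(0) >= max_len:
            padded.append(seq[:max_len])
        else:
            pad_shape = (max_len - seq.size(0),) + tuple(seq.shape[1:])
            padded.append(torch.cat([seq, seq.new_full(pad_shape, padding_value)]))
    out = torch.stack(padded)
    return out if batch_first else out.transpose(0, 1)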
......@@ -4,7 +4,7 @@ from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
from SuperTagger.Linker.atom_map import atom_map
def category_to_atoms(category, category_to_atoms):
def category_to_atoms(category, categories_to_atoms):
res = [i for i in atom_map.keys() if category in i]
if len(res) > 0:
return [category]
......@@ -12,19 +12,19 @@ def category_to_atoms(category, category_to_atoms):
category_cut = re.search(r'\w*\(\d+,(.+),(.+)\)', category)
left_side, right_side = category_cut.group(1), category_cut.group(2)
category_to_atoms += category_to_atoms(left_side, [])
category_to_atoms += category_to_atoms(right_side, [])
categories_to_atoms += category_to_atoms(left_side, [])
categories_to_atoms += category_to_atoms(right_side, [])
return category_to_atoms
return categories_to_atoms
def get_atoms_batch(category_batch):
batch = []
for sentence in category_batch:
category_to_atoms = []
categories_to_atoms = []
for category in sentence:
category_to_atoms = category_to_atoms(category, category_to_atoms)
batch.append(category_to_atoms)
categories_to_atoms = category_to_atoms(category, categories_to_atoms)
batch.append(categories_to_atoms)
return batch
......
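A quick check of the category-splitting regex used in category_to_atoms (the 'dr(0,np,n)' category syntax is assumed here purely for illustration):

import re

category = "dr(0,np,n)"  # hypothetical compound category
match = re.search(r'\w*\(\d+,(.+),(.+)\)', category)
print(match.group(1), match.group(2))  # -> np n

Each side is then fed back into category_to_atoms until only atomic types from atom_map remain.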
......@@ -2,6 +2,7 @@ import torch
from torch import Tensor
from torch.nn import Module
from torch.nn.functional import nll_loss, cross_entropy
from SuperTagger.Linker.atom_map import atom_map
from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes
......@@ -26,12 +27,20 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred):
# then convert into atom_vocab_size lists of (batch_size, max atoms in one category) using a prefix traversal of the graph
atoms_polarity = find_pos_neg_idexes(atoms_batch)
axiom_links_true = ""
num_correct_links = 0
for atom_type in atom_map.keys():
# filter atoms_batch for this atom type, then filter with the indices on atom polarity
# match axiom_links_pred and true data
# build the list of positive and the list of negative atoms
correct_links = torch.ones(axiom_links_pred.size())
correct_links[axiom_links_pred != axiom_links_true] = 0
num_correct_links = correct_links.sum().item()
# match by index
axiom_links_true = ""
# match axiom_links_pred and true data
correct_links = torch.ones(axiom_links_pred.size())
correct_links[axiom_links_pred != axiom_links_true] = 0
num_correct_links += correct_links.sum().item()
return num_correct_links/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1])
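As a sanity check of the final ratio computed above, a small worked example with made-up link indices (2 sentences, 4 links each):

import torch

axiom_links_pred = torch.tensor([[0, 1, 2, 3],
                                 [3, 2, 1, 0]])
axiom_links_true = torch.tensor([[0, 1, 2, 0],
                                 [3, 2, 1, 0]])

correct_links = torch.ones(axiom_links_pred.size())
correct_links[axiom_links_pred != axiom_links_true] = 0   # one mismatch
num_correct_links = correct_links.sum().item()
print(num_correct_links / (axiom_links_pred.size()[0] * axiom_links_pred.size()[1]))  # 7/8 = 0.875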
......@@ -30,25 +30,29 @@ atoms_polarity = [[False, True, True, False, False, True, True, False],
atoms_encoding = torch.randn((2, 8, 24))
matches = []
link_weights=[]
for atom_type in ["np", "v"]:
pos_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))]
print(pos_idx_per_atom_type)
neg_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
not x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))]
not x and atoms_batch[s_idx][i] == atom_type] for s_idx in
range(len(atoms_polarity))]
# TODO: select with a list of lists
pos_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=3, padding_value=0)
neg_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=3, padding_value=0)
print(neg_encoding.shape)
pos_encoding = pad_sequence(
[atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=3,
padding_value=0)
neg_encoding = pad_sequence(
[atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=3,
padding_value=0)
# pos_encoding = self.pos_transformation(pos_encoding)
# neg_encoding = self.neg_transformation(neg_encoding)
weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
print(weights.shape)
print("sinkhorn")
print(sinkhorn(weights, iters=3).shape)
matches.append(sinkhorn(weights, iters=3))
link_weights.append(sinkhorn(weights, iters=3))
print(matches)
print(torch.cat([link_weights[i].unsqueeze(0) for i in range(len(link_weights))]).shape)
\ No newline at end of file
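The sinkhorn function used in this script comes from elsewhere in the repository and is not part of this diff; the snippet below is only an assumption about what it computes, namely a standard log-space Sinkhorn normalisation that pushes the score matrix towards a doubly-stochastic (soft permutation) matrix:

import torch

def norm(x, dim):
    # normalise in log space along one dimension
    return x - torch.logsumexp(x, dim=dim, keepdim=True)

def sinkhorn(weights, iters=3):
    # alternately normalise rows and columns of the log-score matrix
    log_alpha = weights
    for _ in range(iters):
        log_alpha = norm(log_alpha, dim=-1)  # rows
        log_alpha = norm(log_alpha, dim=-2)  # columns
    return log_alpha

With weights of shape (batch, n, n), sinkhorn(weights, iters=3) keeps the input shape, which is what the shape prints above rely on.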