Commit 78ec3dc4 authored by Caroline DE POURTALES

starting train

parent a702fd51
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
@@ -6,6 +6,7 @@ symbols_vocab_size=26
 atom_vocab_size=12
 max_len_sentence=148
 max_symbols_in_sentence=1250
+max_atoms_in_one_type=50

 [MODEL_ENCODER]
 dim_encoder = 768
from torch import Tensor
from torch.nn import (GELU, Dropout, LayerNorm, Linear, Module, MultiheadAttention,
                      Sequential)

from Configuration import Configuration
from SuperTagger.Symbol.SymbolEmbedding import SymbolEmbedding

class FullyConnectedFeedForward(Module):
    "Implements the FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FullyConnectedFeedForward, self).__init__()
        self.ffn = Sequential(
            Linear(d_model, d_ff, bias=False),
            GELU(),
            Dropout(dropout),
            Linear(d_ff, d_model, bias=False)
        )

    def forward(self, x):
        return self.ffn(x)

class AttentionDecoderLayer(Module):
    r"""AttentionDecoderLayer is made up of self-attn, multi-head attn and a feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    Args (read from ``Configuration.modelDecoderConfig``):
        dim_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.
        norm_first: if ``True``, layer norm is done prior to self attention, multihead
            attention and feedforward operations, respectively. Otherwise it's done after.
            Default: ``False`` (after).
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self) -> None:
        factory_kwargs = {'device': Configuration.modelDecoderConfig['device'],
                          'dtype': Configuration.modelDecoderConfig['dtype']}
        super(AttentionDecoderLayer, self).__init__()

        # init params
        dim_encoder = int(Configuration.modelDecoderConfig['dim_encoder'])
        dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
        self.max_symbols_in_sentence = int(Configuration.modelDecoderConfig['max_symbols_in_sentence'])
        nhead = int(Configuration.modelDecoderConfig['nhead'])
        dropout = float(Configuration.modelDecoderConfig['dropout'])
        dim_feedforward = int(Configuration.modelDecoderConfig['dim_feedforward'])
        layer_norm_eps = float(Configuration.modelDecoderConfig['layer_norm_eps'])
        self.nhead = nhead
        self.dim_decoder = dim_decoder
        # assumed to come from the dataset config (cf. symbols_vocab_size=26 above)
        self.symbols_vocab_size = int(Configuration.datasetConfig['symbols_vocab_size'])

        self.symbols_embedder = SymbolEmbedding(self.dim_decoder, self.symbols_vocab_size)

        # layers
        self.dropout = Dropout(dropout)
        self.self_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout, batch_first=True,
                                            kdim=dim_decoder, vdim=dim_decoder)
        self.norm1 = LayerNorm(dim_decoder, eps=layer_norm_eps)
        self.multihead_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout,
                                                 kdim=dim_encoder, vdim=dim_encoder,
                                                 batch_first=True)
        self.norm2 = LayerNorm(dim_decoder, eps=layer_norm_eps)
        self.ffn = FullyConnectedFeedForward(d_model=dim_decoder, d_ff=dim_feedforward, dropout=dropout)
        self.norm3 = LayerNorm(dim_decoder, eps=layer_norm_eps)

    def forward(self, symbols_tokens: Tensor, sents_embedding: Tensor, encoder_mask: Tensor,
                decoder_mask: Tensor) -> Tensor:
        r"""Pass the inputs through the decoder layer.

        Args:
            symbols_tokens: the symbol token sequence fed to the decoder layer (required).
            sents_embedding: the sequence from the last layer of the encoder (required).
            encoder_mask: mask for the encoder-decoder attention (required).
            decoder_mask: mask for the self-attention (required).
        """
        x = self.symbols_embedder(symbols_tokens)
        x = self.norm1(x + self._mask_mha_block(x, decoder_mask))
        x = self.norm2(x + self._mha_block(x, sents_embedding, encoder_mask))
        x = self.norm3(x + self._ff_block(x))
        return x

    # self-attention block
    def _mask_mha_block(self, x: Tensor, decoder_mask: Tensor) -> Tensor:
        if decoder_mask is not None:
            # Same mask applied to all h heads.
            decoder_mask = decoder_mask.repeat(self.nhead, 1, 1)
        x = self.self_attn(x, x, x, attn_mask=decoder_mask)[0]
        return x

    # multihead attention block
    def _mha_block(self, x: Tensor, sents_embs: Tensor, encoder_mask: Tensor) -> Tensor:
        if encoder_mask is not None:
            # Same mask applied to all h heads.
            encoder_mask = encoder_mask.repeat(self.nhead, 1, 1)
        x = self.multihead_attn(x, sents_embs, sents_embs, attn_mask=encoder_mask)[0]
        return x

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        return self.ffn(x)
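
As a side note on the mask handling above: `torch.nn.MultiheadAttention` with `batch_first=True` accepts a 3D attention mask of shape (batch_size * num_heads, target_len, source_len), which is why `_mask_mha_block` and `_mha_block` repeat their masks over the heads. A minimal, self-contained sketch of that contract (not part of this commit; tensor sizes are illustrative, not the configured ones):

import torch
from torch.nn import MultiheadAttention

batch_size, seq_len, dim_decoder, nhead = 2, 6, 16, 4
self_attn = MultiheadAttention(dim_decoder, nhead, batch_first=True)

x = torch.randn(batch_size, seq_len, dim_decoder)  # (batch, seq, feature) with batch_first=True

# Boolean causal mask: True marks positions that may NOT be attended to.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
# Expand to the 3D form (batch_size * num_heads, L, S) expected for per-head masks.
attn_mask = causal_mask.unsqueeze(0).repeat(batch_size * nhead, 1, 1)

out, _ = self_attn(x, x, x, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 6, 16])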
@@ -41,6 +41,7 @@ class Linker(Module):
         self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
         self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
+        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
         self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
         self.dropout = Dropout(0.1)
@@ -100,11 +101,11 @@ class Linker(Module):
         # to do select with list of list
         pos_encoding = pad_sequence(
             [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
-             for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_sentence,
+             for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_one_type//2,
             padding_value=0)
         neg_encoding = pad_sequence(
             [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
-             for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_sentence,
+             for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_one_type//2,
             padding_value=0)
         # pos_encoding = self.pos_transformation(pos_encoding)
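
The `pad_sequence` called here takes a `max_len` keyword, so it cannot be `torch.nn.utils.rnn.pad_sequence`; it is presumably a project helper. A minimal sketch of what such a helper might look like, under that assumption (the signature is inferred from the call sites, the body is not from the repository):

import torch

def pad_sequence(sequences, max_len, padding_value=0):
    """Hypothetical helper: stack variable-length (n_i, dim) tensors into a
    (batch, max_len, dim) tensor, padding (or truncating) along the first axis."""
    dim = sequences[0].size(-1)
    out = torch.full((len(sequences), max_len, dim), float(padding_value))
    for i, seq in enumerate(sequences):
        length = min(seq.size(0), max_len)
        out[i, :length] = seq[:length]
    return out

# e.g. two sentences contributing 2 and 0 positive atoms of some type, with dim 24
batch = [torch.randn(2, 24), torch.randn(0, 24)]
print(pad_sequence(batch, max_len=25, padding_value=0).shape)  # torch.Size([2, 25, 24])

With the new `max_atoms_in_one_type=50` config entry, each atom type then contributes fixed `(batch, 25, dim)` positive and negative blocks instead of padding to `max_atoms_in_sentence`.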
@@ -4,7 +4,7 @@ from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
 from SuperTagger.Linker.atom_map import atom_map

-def category_to_atoms(category, category_to_atoms):
+def category_to_atoms(category, categories_to_atoms):
     res = [i for i in atom_map.keys() if category in i]
     if len(res) > 0:
         return [category]
@@ -12,19 +12,19 @@ def category_to_atoms(category, category_to_atoms):
         category_cut = re.search(r'\w*\(\d+,(.+),(.+)\)', category)
         left_side, right_side = category_cut.group(1), category_cut.group(2)
-        category_to_atoms += category_to_atoms(left_side, [])
-        category_to_atoms += category_to_atoms(right_side, [])
-        return category_to_atoms
+        categories_to_atoms += category_to_atoms(left_side, [])
+        categories_to_atoms += category_to_atoms(right_side, [])
+        return categories_to_atoms


 def get_atoms_batch(category_batch):
     batch = []
     for sentence in category_batch:
-        category_to_atoms = []
+        categories_to_atoms = []
         for category in sentence:
-            category_to_atoms = category_to_atoms(category, category_to_atoms)
-        batch.append(category_to_atoms)
+            categories_to_atoms = category_to_atoms(category, categories_to_atoms)
+        batch.append(categories_to_atoms)
     return batch
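
The rename matters because the old accumulator reused the name of the function it feeds, so `category_to_atoms(category, category_to_atoms)` ended up calling a list. A toy reproduction of the pitfall and of the fixed pattern (the names below are illustrative, not from the repository):

def to_atoms(category, acc):
    # stand-in for category_to_atoms: append the category and return the accumulator
    return acc + [category]

# Pre-rename, the accumulator was bound to the function's own name, so the call
# hit a list and raised "TypeError: 'list' object is not callable".
# Post-rename, the two names are distinct and the accumulation works:
categories_to_atoms = []
for category in ["np", "v"]:
    categories_to_atoms = to_atoms(category, categories_to_atoms)
print(categories_to_atoms)  # ['np', 'v']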
@@ -2,6 +2,7 @@ import torch
 from torch import Tensor
 from torch.nn import Module
 from torch.nn.functional import nll_loss, cross_entropy
+from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes
@@ -26,12 +27,20 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred):
     # then convert into atom_vocab_size lists of (batch_size, max atoms in one cat) with a prefix traversal of the graph
     atoms_polarity = find_pos_neg_idexes(atoms_batch)

-    axiom_links_true = ""
-
-    # match axiom_links_pred and true data
-    correct_links = torch.ones(axiom_links_pred.size())
-    correct_links[axiom_links_pred != axiom_links_true] = 0
-    num_correct_links = correct_links.sum().item()
+    num_correct_links = 0
+    for atom_type in atom_map.keys():
+        # filter atoms_batch on this type, then filter atoms_polarity with those indices
+        # build the + list and the - list
+        # match them by index
+
+        axiom_links_true = ""
+
+        # match axiom_links_pred and true data
+        correct_links = torch.ones(axiom_links_pred.size())
+        correct_links[axiom_links_pred != axiom_links_true] = 0
+        num_correct_links += correct_links.sum().item()

     return num_correct_links/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1])
@@ -30,25 +30,29 @@ atoms_polarity = [[False, True, True, False, False, True, True, False],
 atoms_encoding = torch.randn((2, 8, 24))

-matches = []
+link_weights=[]
 for atom_type in ["np", "v"]:
     pos_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
                               x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))]
-    print(pos_idx_per_atom_type)
     neg_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
-                              not x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))]
+                              not x and atoms_batch[s_idx][i] == atom_type] for s_idx in
+                             range(len(atoms_polarity))]

     # to do select with list of list
-    pos_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
-                                 for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=3, padding_value=0)
-    neg_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
-                                 for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=3, padding_value=0)
-
-    print(neg_encoding.shape)
+    pos_encoding = pad_sequence(
+        [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
+         for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=3,
+        padding_value=0)
+    neg_encoding = pad_sequence(
+        [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
+         for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=3,
+        padding_value=0)
+
+    # pos_encoding = self.pos_transformation(pos_encoding)
+    # neg_encoding = self.neg_transformation(neg_encoding)

     weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-    print(weights.shape)
-    print("sinkhorn")
-    print(sinkhorn(weights, iters=3).shape)
-    matches.append(sinkhorn(weights, iters=3))
+    link_weights.append(sinkhorn(weights, iters=3))

-print(matches)
+print(torch.cat([link_weights[i].unsqueeze(0) for i in range(len(link_weights))]).shape)
\ No newline at end of file
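
This script relies on a `sinkhorn(weights, iters=3)` function that is not part of the commit. For reference, a generic log-domain Sinkhorn normalization over a batch of square weight matrices can be sketched as follows (a common formulation, not necessarily the project's implementation):

import torch

def sinkhorn(log_alpha, iters=3):
    """Generic log-domain Sinkhorn: alternately normalize rows and columns so that
    exp(log_alpha) converges towards a doubly-stochastic matrix."""
    for _ in range(iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)  # row normalization
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)  # column normalization
    return log_alpha

weights = torch.randn(2, 3, 3)           # (batch, n_pos, n_neg), square as in the script
print(sinkhorn(weights, iters=3).shape)  # torch.Size([2, 3, 3])

Assuming the project's sinkhorn likewise preserves the (2, 3, 3) shape per atom type, the final print in the script would show torch.Size([2, 2, 3, 3]).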