Commit c44879ab authored by Caroline DE POURTALES

starting train

parent b54804c0
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
from torch import Tensor
import torch
from torch.nn import (GELU, Dropout, LayerNorm, Linear, Module, MultiheadAttention,
                      Sequential)

from Configuration import Configuration
from SuperTagger.Symbol.SymbolEmbedding import SymbolEmbedding


class FFN(Module):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FFN, self).__init__()
        self.ffn = Sequential(
            Linear(d_model, d_ff, bias=False),
            GELU(),
            Dropout(dropout),
            Linear(d_ff, d_model, bias=False)
        )

    def forward(self, x):
        return self.ffn(x)
class AttentionLayer(Module):
    r"""AttentionLayer is made up of self-attention, multi-head attention and a feed-forward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    Args:
        dim_model: the number of expected features in the input (required).
        nhead: the number of heads in the multi-head attention models (required).
        dim_feedforward: the dimension of the feed-forward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.
        norm_first: if ``True``, layer norm is done prior to self attention, multi-head
            attention and feed-forward operations, respectively. Otherwise it is done after.
            Default: ``False`` (after).
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self) -> None:
        super(AttentionLayer, self).__init__()

        # init params
        dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
        dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
        dim_feedforward = int(Configuration.modelLinkerConfig['dim_feedforward'])
        dropout = float(Configuration.modelLinkerConfig['dropout'])
        layer_norm_eps = float(Configuration.modelLinkerConfig['layer_norm_eps'])
        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
        self.max_symbols_in_sentence = int(Configuration.datasetConfig['max_symbols_in_sentence'])
        self.dim_embedding_atoms = dim_embedding_atoms
        # symbols_vocab_size: assumed to be provided by the dataset configuration
        self.symbols_vocab_size = int(Configuration.datasetConfig['symbols_vocab_size'])
        self.symbols_embedder = SymbolEmbedding(self.dim_embedding_atoms, self.symbols_vocab_size)

        # layers
        self.dropout = Dropout(dropout)
        self.self_attn = MultiheadAttention(dim_embedding_atoms, self.nhead, dropout=dropout, batch_first=True,
                                            kdim=dim_embedding_atoms, vdim=dim_embedding_atoms)
        self.norm1 = LayerNorm(dim_embedding_atoms, eps=layer_norm_eps)
        self.multihead_attn = MultiheadAttention(dim_embedding_atoms, self.nhead, dropout=dropout,
                                                 kdim=dim_encoder, vdim=dim_encoder,
                                                 batch_first=True)
        self.norm2 = LayerNorm(dim_embedding_atoms, eps=layer_norm_eps)
        self.ffn = FFN(d_model=dim_embedding_atoms, d_ff=dim_feedforward, dropout=dropout)
        self.norm3 = LayerNorm(dim_embedding_atoms, eps=layer_norm_eps)
    def forward(self, atoms_embeddings, sents_embedding, encoder_mask, decoder_mask):
        r"""Pass the inputs through the decoder layer.

        Args:
            atoms_embeddings: the sequence fed to the decoder layer (required).
            sents_embedding: the sequence from the last layer of the encoder (required).
            encoder_mask: attention mask over the encoder outputs.
            decoder_mask: attention mask over the atom sequence.
        """
        x = atoms_embeddings
        x = self.norm1(x + self._mask_mha_block(x, decoder_mask))
        x = self.norm2(x + self._mha_block(x, sents_embedding, encoder_mask))
        x = self.norm3(x + self._ff_block(x))
        return x

    # self-attention block
    def _mask_mha_block(self, x: Tensor, decoder_mask: Tensor) -> Tensor:
        if decoder_mask is not None:
            # Same mask applied to all h heads.
            decoder_mask = decoder_mask.repeat(self.nhead, 1, 1)
        x = self.self_attn(x, x, x, attn_mask=decoder_mask)[0]
        return x

    # multi-head attention block
    def _mha_block(self, x: Tensor, sents_embs: Tensor, encoder_mask: Tensor) -> Tensor:
        if encoder_mask is not None:
            # Same mask applied to all h heads.
            encoder_mask = encoder_mask.repeat(self.nhead, 1, 1)
        x = self.multihead_attn(x, sents_embs, sents_embs, attn_mask=encoder_mask)[0]
        return x

    # feed-forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.ffn(x)
        return x
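
For reference, the forward pass above follows the usual post-norm decoder pattern: masked self-attention over the atom embeddings, cross-attention from the atoms into the sentence embeddings, then the feed-forward block, each wrapped in a residual connection and a LayerNorm. Below is a minimal, self-contained sketch of the cross-attention step; the dimensions and tensor names (batch_size, num_atoms, num_words, dim_embedding_atoms, dim_encoder, nhead) are placeholders, not the values read from Configuration.

import torch
from torch.nn import MultiheadAttention

# Hypothetical dimensions standing in for the configuration values.
batch_size, num_atoms, num_words = 2, 10, 8
dim_embedding_atoms, dim_encoder, nhead = 32, 48, 4

atoms = torch.randn(batch_size, num_atoms, dim_embedding_atoms)   # decoder-side input
sentences = torch.randn(batch_size, num_words, dim_encoder)       # encoder output

cross_attn = MultiheadAttention(dim_embedding_atoms, nhead, batch_first=True,
                                kdim=dim_encoder, vdim=dim_encoder)

# An additive (batch, target_len, source_len) mask, repeated over heads as in _mha_block.
encoder_mask = torch.zeros(batch_size, num_atoms, num_words)
out, _ = cross_attn(atoms, sentences, sentences,
                    attn_mask=encoder_mask.repeat(nhead, 1, 1))
print(out.shape)  # torch.Size([2, 10, 32])

Repeating the (batch, target, source) mask nhead times along the first dimension yields the (batch * num_heads, target, source) attn_mask shape that MultiheadAttention accepts, which is what the `repeat(self.nhead, 1, 1)` calls in the layer are doing.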
@@ -3,17 +3,34 @@ from itertools import chain
 import torch
 from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU
 from torch.nn import Module
+import torch.nn.functional as F

 from Configuration import Configuration
 from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
 from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
 from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch, mesure_accuracy
-from SuperTagger.Linker.AttentionLayer import FFN, AttentionLayer
+from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
+from SuperTagger.eval import mesure_accuracy
 from SuperTagger.utils import pad_sequence
+
+
+class FFN(Module):
+    "Implements FFN equation."
+
+    def __init__(self, d_model, d_ff, dropout=0.1):
+        super(FFN, self).__init__()
+        self.ffn = Sequential(
+            Linear(d_model, d_ff, bias=False),
+            GELU(),
+            Dropout(dropout),
+            Linear(d_ff, d_model, bias=False)
+        )
+
+    def forward(self, x):
+        return self.ffn(x)


 class Linker(Module):
     def __init__(self):
         super(Linker, self).__init__()
@@ -33,7 +50,7 @@ class Linker(Module):
         self.atom_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)

         # to do: define an encoding
-        self.linker_encoder = AttentionLayer()
+        # self.linker_encoder =

         self.pos_transformation = Sequential(
             FFN(self.dim_polarity_transfo, self.dim_polarity_transfo, 0.1),
@@ -49,7 +66,7 @@ class Linker(Module):
         decoder_attn_mask[atoms_batch.eq(self.padding_id)] = 0.0
         return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_batch.shape[1], 1)

-    def forward(self, category_batch, sents_embedding, sents_mask):
+    def forward(self, category_batch, sents_embedding):
         '''
         Parameters :
         category_batch : batch of size (batch_size, sequence_length) = output of decoder
@@ -96,26 +113,26 @@ class Linker(Module):
             weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
             link_weights.append(sinkhorn(weights, iters=3))

-        return link_weights
+        return torch.cat([link_weights[i].unsqueeze(0) for i in range(len(link_weights))])

-    def predict_axiom_links(self, b_sents_tokenized, b_sents_mask):
-        return None
-
-    def eval_batch(self, batch, cross_entropy_loss):
-        b_sents_tokenized = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
-        b_sents_mask = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
-        b_category = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-
-        logits_axiom_links_pred = self.predict_axiom_links(b_sents_tokenized, b_sents_mask)
+    def eval_batch(self, supertagger, batch, cross_entropy_loss):
+        batch_categories = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
+        batch_sentences = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
+        batch_axiom_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
+
+        batch_sentences_embedding = supertagger(batch_sentences, batch_sentences)
+
+        logits_axiom_links_pred = self.forward(batch_categories, batch_sentences_embedding)

         # Softmax and argmax
-        axiom_links_pred = torch.argmax(torch.nn.functional.softmax(logits_axiom_links_pred, dim=2), dim=2)
-
-        accuracy = mesure_accuracy(b_category, axiom_links_pred)
-        loss = float(cross_entropy_loss(axiom_links_pred, b_category))
+        axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=2), dim=2)
+
+        accuracy = mesure_accuracy(batch_axiom_links, axiom_links_pred)
+        loss = float(cross_entropy_loss(axiom_links_pred, batch_axiom_links))

         return accuracy, loss

-    def eval_epoch(self, dataloader, cross_entropy_loss):
+    def eval_epoch(self, supertagger, dataloader, cross_entropy_loss):
         r"""Average the evaluation of all the batch.

         Args:
@@ -126,7 +143,7 @@ class Linker(Module):
         compt = 0
         for step, batch in enumerate(dataloader):
             compt += 1
-            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
+            accuracy, loss = self.eval_batch(supertagger, batch, cross_entropy_loss)
             accuracy_average += accuracy
             loss_average += loss
...
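
The change to forward above stacks the per-atom-type Sinkhorn outputs into a single tensor so that eval_batch can apply softmax and argmax over the last dimension. As a rough illustration of that matching step, here is a self-contained sketch with a generic log-domain Sinkhorn normalization standing in for the project's sinkhorn_fn_no_exp, and made-up sizes for the positive/negative atom encodings:

import torch

def log_sinkhorn(scores, iters=3):
    # Alternately normalize log-scores over columns and rows (log-domain Sinkhorn).
    log_alpha = scores
    for _ in range(iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)
    return log_alpha

batch_size, max_atoms_of_one_type, dim = 2, 5, 32            # made-up sizes
pos_encoding = torch.randn(batch_size, max_atoms_of_one_type, dim)
neg_encoding = torch.randn(batch_size, max_atoms_of_one_type, dim)

# Bilinear affinity between positive and negative occurrences of one atom type,
# mirroring `weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))`.
weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
link_logits = log_sinkhorn(weights, iters=3)                 # (batch, pos, neg)

# Reading off predicted axiom links: for each positive atom, its matched negative atom.
axiom_links_pred = torch.argmax(link_logits, dim=2)
print(axiom_links_pred.shape)                                # torch.Size([2, 5])

After the Sinkhorn iterations each weight matrix is approximately doubly stochastic in log space, so the argmax over the negative axis reads off one candidate axiom link per positive atom.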
@@ -94,8 +94,3 @@ def find_pos_neg_idexes(batch_symbols):
     return list_batch
-
-
-def mesure_accuracy(b_category, axiom_links_pred):
-    # Convert b_category into
-    return 0
\ No newline at end of file
@@ -11,3 +11,23 @@ class SinkhornLoss(Module):
     def forward(self, predictions, truths):
         return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
                    for link, perm in zip(predictions, truths))
+
+
+def mesure_accuracy(batch_axiom_links, axiom_links_pred):
+    r"""
+    batch_axiom_links : (batch_size, ...)
+    axiom_links_pred : (batch_size, max_atoms_type_polarity)
+    """
+    # Convert batch_axiom_links into a list of atoms (batch_size, max_atoms_in_sentence),
+    # then into atom_vocab_size lists of (batch_size, max atoms in one category) via a prefix traversal of the graph
+    axiom_links_true = ""
+
+    # match axiom_links_pred against the true data
+    correct_links = torch.ones(axiom_links_pred.size())
+    correct_links[axiom_links_pred != axiom_links_true] = 0
+    num_correct_links = correct_links.sum().item()
+
+    return num_correct_links
\ No newline at end of file
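
As committed, mesure_accuracy still carries a placeholder for the ground truth (axiom_links_true = "") and returns a raw count rather than a ratio. Below is a hedged sketch of the comparison it appears to be heading towards, assuming the true links arrive as an integer tensor of the same shape as the predictions and that a pad_id marks padded positions; the function name, pad_id and the padding convention are assumptions, not part of this commit.

import torch

def measure_links_accuracy(axiom_links_true, axiom_links_pred, pad_id=-1):
    # Count matching links, ignoring padded positions, and normalize to an accuracy.
    mask = axiom_links_true != pad_id
    correct_links = (axiom_links_pred == axiom_links_true) & mask
    return correct_links.sum().item() / mask.sum().item()

# Example with made-up link indices: 3 of the 4 real links match.
true_links = torch.tensor([[2, 0, 1, 3, -1]])
pred_links = torch.tensor([[2, 0, 3, 3, 1]])
print(measure_links_accuracy(true_links, pred_links))  # 0.75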
@@ -127,13 +127,13 @@ def run_epochs(epochs):
             optimizer_linker.zero_grad()

             # Find the prediction of categories to feed the linker and the sentences embedding
-            category_logits_pred, sents_embedding, sents_mask = supertagger(batch_categories, batch_sentences)
+            category_logits_pred, sents_embedding = supertagger(batch_categories, batch_sentences)

             # Predict the categories from prediction with argmax and softmax
             category_batch = torch.argmax(torch.nn.functional.softmax(category_logits_pred, dim=2), dim=2)

             # Run the linker on the categories predictions
-            logits_predictions = linker(category_batch, sents_embedding, sents_mask)
+            logits_predictions = linker(category_batch, sents_embedding)

             linker_loss = cross_entropy_loss(logits_predictions, batch_axiom_links)

             # Perform a backward pass to calculate the gradients.
@@ -145,7 +145,6 @@ def run_epochs(epochs):
             # Update parameters and take a step using the computed gradient.
             optimizer_linker.step()
             scheduler_linker.step()

         avg_train_loss = total_train_loss / len(training_dataloader)
@@ -157,7 +156,7 @@ def run_epochs(epochs):
         linker.eval()
         with torch.no_grad():
             print("Start eval")
-            accuracy_sents, accuracy_atom, v_loss = linker.eval_epoch(validation_dataloader, cross_entropy_loss)
+            accuracy_sents, accuracy_atom, v_loss = linker.eval_epoch(supertagger, validation_dataloader, cross_entropy_loss)
             print("")
             print("  Average accuracy sents on epoch: {0:.2f}".format(accuracy_sents))
             print("  Average accuracy atom on epoch: {0:.2f}".format(accuracy_atom))
...
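
For orientation, the training step in this hunk is the standard PyTorch update loop: zero the gradients, run the supertagger and linker forward, compute the loss, backpropagate, then step the optimizer and scheduler. A generic, self-contained version of that pattern with a dummy module standing in for the actual supertagger/linker pair:

import torch
from torch.nn import Linear, CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

# Dummy stand-in model, just to show the shape of the update loop.
model = Linear(16, 4)
optimizer = AdamW(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=10, gamma=0.9)
loss_fn = CrossEntropyLoss()

inputs = torch.randn(8, 16)
targets = torch.randint(0, 4, (8,))

optimizer.zero_grad()                 # reset gradients for this batch
logits = model(inputs)                # forward pass
loss = loss_fn(logits, targets)       # differentiable loss on the logits
loss.backward()                       # backward pass computes gradients
optimizer.step()                      # update parameters
scheduler.step()                      # advance the learning-rate schedule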