Commit 154eabc1 authored by Caroline DE POURTALES

architecture and main

parent 85a4adab
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
Showing changed files with 250 additions and 23 deletions
File moved
import torch
from SuperTagger.utils import pad_sequence
from ..utils import pad_sequence
class AtomTokenizer(object):
......
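The AtomTokenizer body is collapsed in this diff. As a hedged sketch of the role it plays elsewhere in the commit (mapping atom symbols to ids and padding each sentence to max_atoms_in_sentence) — only the constructor arguments and convert_batchs_to_ids are taken from how Linker.py uses the class; everything else below is an assumption:

import torch

class AtomTokenizerSketch:
    def __init__(self, atom_map, max_atoms_in_sentence):
        self.atom_map = atom_map                          # atom symbol -> id, with '[PAD]' present
        self.max_atoms_in_sentence = max_atoms_in_sentence
        self.pad_id = atom_map['[PAD]']

    def convert_batchs_to_ids(self, atoms_batch):
        # atoms_batch: list of lists of atom symbols, one inner list per sentence
        ids = [[self.atom_map[a] for a in sent] for sent in atoms_batch]
        ids = [sent + [self.pad_id] * (self.max_atoms_in_sentence - len(sent)) for sent in ids]
        return torch.tensor(ids)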
from itertools import chain
import torch
from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention
from torch.nn import Sequential, LayerNorm, Dropout
from torch.nn import Module
import torch.nn.functional as F
import sys
from Configuration import Configuration
from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
from SuperTagger.Linker.MHA import AttentionDecoderLayer
from SuperTagger.Linker.atom_map import atom_map
from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch, FFN
from SuperTagger.eval import mesure_accuracy
from SuperTagger.utils import pad_sequence
from AtomEmbedding import AtomEmbedding
from AtomTokenizer import AtomTokenizer
from MHA import AttentionDecoderLayer
from atom_map import atom_map
from Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN
from eval import mesure_accuracy
from ..utils import pad_sequence
class Linker(Module):
@@ -30,11 +28,12 @@ class Linker(Module):
self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
self.dropout = Dropout(0.1)
self.device = ""
self.atom_map = atom_map
self.padding_id = self.atom_map['[PAD]']
self.atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
self.atom_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
# TODO: define an encoding
self.linker_encoder = AttentionDecoderLayer()
@@ -48,7 +47,7 @@ class Linker(Module):
LayerNorm(self.dim_embedding_atoms, eps=1e-12)
)
def make_decoder_mask(self, atoms_token) :
def make_decoder_mask(self, atoms_token):
decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
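For readability, a minimal illustrative example (not part of the commit) of what make_decoder_mask produces, assuming a toy batch with padding_id = 0 and nhead = 2: real atoms keep weight 1, padded positions are zeroed, and the row mask is expanded to one square mask per attention head.

import torch

atoms_token = torch.tensor([[3, 5, 0, 0]])      # one sentence: two real atoms, two [PAD]
padding_id, nhead = 0, 2

mask = torch.ones_like(atoms_token, dtype=torch.float64)
mask[atoms_token.eq(padding_id)] = 0.0          # zero out padded positions
mask = mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(nhead, 1, 1)

print(mask.shape)  # torch.Size([2, 4, 4]) -> (batch_size * nhead, seq_len, seq_len)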
@@ -65,29 +64,34 @@ class Linker(Module):
'''
# atoms embedding
atoms_embedding = self.atom_embedding(atoms_batch_tokenized)
print(atoms_embedding.shape)
atoms_embedding = self.atoms_embedding(atoms_batch_tokenized)
# MHA or LSTM with the BERT output
sents_embedding = torch.randn(32, self.max_len_sentence, self.dim_encoder)
batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence)
atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask, self.make_decoder_mask(atoms_batch_tokenized))
#atoms_encoding = atoms_embedding
atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
self.make_decoder_mask(atoms_batch_tokenized))
link_weights = []
for atom_type in list(self.atom_map.keys())[:-1]:
pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
not atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
atoms_batch_tokenized[s_idx][i] == self.atom_map[
atom_type] and
atoms_polarity_batch[s_idx][i])] + [
torch.zeros(self.dim_embedding_atoms)])
for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2)
neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
atoms_batch_tokenized[s_idx][i] == self.atom_map[
atom_type] and
not atoms_polarity_batch[s_idx][i])] + [
torch.zeros(self.dim_embedding_atoms)])
for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2)
pos_encoding = self.pos_transformation(pos_encoding)
neg_encoding = self.neg_transformation(neg_encoding)
@@ -97,11 +101,68 @@ class Linker(Module):
return torch.stack(link_weights)
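For each atom type, the loop above gathers the encodings of positive and negative occurrences, pads both sides to max_atoms_in_one_type // 2, and projects them with pos_transformation / neg_transformation; the collapsed lines (shown in full in predict below) score every positive/negative pair with torch.bmm and run a few Sinkhorn iterations so the link weights become approximately doubly stochastic. The repository's sinkhorn_fn_no_exp is not shown in this diff; the sketch below is a generic log-domain Sinkhorn standing in for it, not the actual implementation.

import torch

def sinkhorn_log_sketch(weights, iters=3):
    # weights: (batch, n_pos, n_neg) raw dot-product scores, treated as log-space values
    log_alpha = weights
    for _ in range(iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)  # normalise rows
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)  # normalise columns
    return log_alpha

scores = torch.bmm(torch.randn(2, 4, 8), torch.randn(2, 8, 4))  # plays the role of pos_encoding @ neg_encoding^T
link_weights_sketch = sinkhorn_log_sketch(scores, iters=3)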
def predict(self, categories, sents_embedding, sents_mask=None):
r'''
Parameters :
categories : (batch_size, len_sentence)
sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
sents_mask
Returns :
axiom_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat)
'''
self.eval()
batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
# get atoms
atoms_batch = get_atoms_batch(categories)
atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
# get polarities
polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)
# atoms embedding
atoms_embedding = self.atoms_embedding(atoms_tokenized)
# MHA or LSTM with the BERT output
atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
self.make_decoder_mask(atoms_tokenized))
link_weights = []
for atom_type in list(self.atom_map.keys())[:-1]:
pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
atoms_tokenized[s_idx][i] == self.atom_map[
atom_type] and
polarities[s_idx][i])] + [
torch.zeros(self.dim_embedding_atoms)])
for s_idx in range(len(polarities))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2)
neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
atoms_tokenized[s_idx][i] == self.atom_map[
atom_type] and
not polarities[s_idx][i])] + [
torch.zeros(self.dim_embedding_atoms)])
for s_idx in range(len(polarities))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2)
pos_encoding = self.pos_transformation(pos_encoding)
neg_encoding = self.neg_transformation(neg_encoding)
weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
link_weights.append(sinkhorn(weights, iters=3))
logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
axiom_links = torch.argmax(F.softmax(logits_predictions, dim=3), dim=3)
return axiom_links
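Illustrative note (not in the commit) on reading the output of predict: after the permute, logits_predictions is (batch_size, n_atom_types, max_atoms_in_one_type // 2, max_atoms_in_one_type // 2), so the argmax gives, for each positive atom of a given type, the index of the negative atom it links to. The shapes below are toy values, not the project's configuration.

import torch
import torch.nn.functional as F

logits_predictions = torch.randn(1, 17, 9, 9)    # (batch, n_atom_types, n_pos, n_neg), toy sizes
axiom_links = torch.argmax(F.softmax(logits_predictions, dim=3), dim=3)
print(axiom_links.shape)                         # torch.Size([1, 17, 9]); axiom_links[b, t, i] = j links pos atom i to neg atom j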
def eval_batch(self, batch, cross_entropy_loss):
batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
#batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
# batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, [])
logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
@@ -128,3 +189,33 @@ class Linker(Module):
loss_average += loss
return accuracy_average / compt, loss_average / compt
def load_weights(self, model_file):
print("#" * 15)
try:
params = torch.load(model_file, map_location=self.device)
args = params['args']
self.atom_map = args['atom_map']
self.max_atoms_in_sentence = args['max_atoms_in_sentence']
self.atoms_tokenizer = AtomTokenizer(self.atom_map, self.max_atoms_in_sentence)
self.atoms_embedding.load_state_dict(params['atoms_embedding'])
self.linker_encoder.load_state_dict(params['linker_encoder'])
self.pos_transformation.load_state_dict(params['pos_transformation'])
self.neg_transformation.load_state_dict(params['neg_transformation'])
print("\n The loading checkpoint was successful ! \n")
except Exception as e:
print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr)
raise e
print("#" * 15)
def __checkpoint_save(self, path='/linker.pt'):
self.linker.cpu()
torch.save({
'args': dict(atom_map=self.atom_map, max_atoms_in_sentence=self.max_atoms_in_sentence),
'atoms_embedding': self.atoms_embedding.state_dict(),
'linker_encoder': self.linker_encoder.state_dict(),
'pos_transformation': self.pos_transformation.state_dict(),
'neg_transformation': self.neg_transformation.state_dict()
}, path)
self.linker.to(self.device)
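For context, a hedged sketch of how these checkpoints round-trip using the same keys as __checkpoint_save and load_weights above; the standalone helper below is an assumption for illustration only, since the commit keeps saving private inside the class.

import torch

def save_linker_sketch(linker, path="models/linker.pt"):
    # mirrors the keys written by __checkpoint_save and read by load_weights
    torch.save({
        'args': dict(atom_map=linker.atom_map, max_atoms_in_sentence=linker.max_atoms_in_sentence),
        'atoms_embedding': linker.atoms_embedding.state_dict(),
        'linker_encoder': linker.linker_encoder.state_dict(),
        'pos_transformation': linker.pos_transformation.state_dict(),
        'neg_transformation': linker.neg_transformation.state_dict()
    }, path)

# linker = Linker()
# save_linker_sketch(linker)
# linker.load_weights("models/linker.pt")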
import copy
import torch
import torch.nn.functional as F
import torch.optim as optim
from Configuration import Configuration
from torch import Tensor, LongTensor
from torch.nn import (GELU, LSTM, Dropout, LayerNorm, Linear, Module, MultiheadAttention,
ModuleList, Sequential)
from torch import Tensor
from torch.nn import (Dropout, LayerNorm, Module, MultiheadAttention)
from SuperTagger.Linker.utils import FFN
from Configuration import Configuration
from utils_linker import FFN
class AttentionDecoderLayer(Module):
......
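The AttentionDecoderLayer body is collapsed in this diff. Below is one plausible sketch consistent with the imports kept here (Dropout, LayerNorm, Module, MultiheadAttention) and with the call linker_encoder(atoms_embedding, sents_embedding, sents_mask, decoder_mask) in Linker.py; the dimensions, head count, and layer structure are assumptions, not the commit's actual implementation.

from torch.nn import Dropout, LayerNorm, Module, MultiheadAttention

class AttentionDecoderLayerSketch(Module):
    def __init__(self, dim=256, nhead=8, dropout=0.1):
        super().__init__()
        self.self_attn = MultiheadAttention(dim, nhead, batch_first=True)
        self.cross_attn = MultiheadAttention(dim, nhead, batch_first=True)
        self.norm1 = LayerNorm(dim)
        self.norm2 = LayerNorm(dim)
        self.dropout = Dropout(dropout)

    def forward(self, atoms, sents, sents_mask=None, decoder_mask=None):
        # self-attention over atom embeddings; the 0/1 float mask from make_decoder_mask
        # is converted to the boolean form attn_mask expects (True = position is masked)
        bool_mask = decoder_mask.eq(0) if decoder_mask is not None else None
        x, _ = self.self_attn(atoms, atoms, atoms, attn_mask=bool_mask)
        atoms = self.norm1(atoms + self.dropout(x))
        # cross-attention: atoms attend to the BERT sentence embeddings;
        # a float sents_mask, if given, acts as an additive attention mask
        x, _ = self.cross_attn(atoms, sents, sents, attn_mask=sents_mask)
        return self.norm2(atoms + self.dropout(x))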
File moved
File moved
import torch
from torch import Tensor
from torch.nn import Module
from torch.nn.functional import nll_loss, cross_entropy
from SuperTagger.Linker.atom_map import atom_map
import re
from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes
from SuperTagger.utils import pad_sequence
from torch.nn.functional import nll_loss
class SinkhornLoss(Module):
......
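The SinkhornLoss body is collapsed above. One plausible loss consistent with the nll_loss import is to treat the Sinkhorn output as log-probabilities over negative-atom indices and score the true links; this is an assumption about the implementation, not the commit's code.

import torch
from torch.nn import Module
from torch.nn.functional import nll_loss

class SinkhornLossSketch(Module):
    def forward(self, logits_predictions, true_links):
        # logits_predictions: (batch, n_types, n_pos, n_neg) log-space link weights
        # true_links:         (batch, n_types, n_pos) index of the matched negative atom
        return nll_loss(logits_predictions.flatten(0, 2), true_links.flatten(), reduction="mean")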
import re
import regex
import numpy as np
import torch
from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention
from torch.nn import Sequential, Linear, Dropout, GELU
from torch.nn import Module
from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
from SuperTagger.Linker.atom_map import atom_map
from SuperTagger.utils import pad_sequence
from atom_map import atom_map
from ..utils import pad_sequence
class FFN(Module):
......
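The FFN body is collapsed above. A hedged sketch of a position-wise feed-forward block consistent with the imports kept in utils_linker (Sequential, Linear, Dropout, GELU); the constructor signature is an assumption, not the commit's actual one.

from torch.nn import Dropout, GELU, Linear, Module, Sequential

class FFNSketch(Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.net = Sequential(
            Linear(d_model, d_ff),
            GELU(),
            Dropout(dropout),
            Linear(d_ff, d_model),
        )

    def forward(self, x):
        return self.net(x)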
File deleted
File deleted
File deleted
File deleted
File deleted
File deleted
File deleted
File deleted
File deleted
main.py 0 → 100644
import torch.nn.functional as F
from Configuration import Configuration
from Linker.Linker import Linker
max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
# categories tagger
tagger = SuperTagger()
tagger.load_weights("models/model_check.pt")
# axiom linker
linker = Linker()
linker.load_weights("models/linker.pt")
# predict categories and links for this sentence
sentence = [[]]
categories, sentence_embedding = tagger.predict(sentence)
axiom_links = linker.predict(categories, sentence_embedding)
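Note that the visible main.py hunk uses SuperTagger() without importing it; the line below is a guess at the missing import, with an assumed module path, since the tagger's location is not shown in this diff.

# assumed import, not shown in the diff; adjust the path to wherever SuperTagger actually lives
from SuperTagger.SuperTagger import SuperTagger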