Commit cea987d4 authored by Caroline de Pourtalès

Merge branch 'change-preprocess' into 'version-linker'

Change preprocess

See merge request !2
parents f028d834 2889028a
Related merge requests: !6 Linker with transformer, !5 Linker with transformer, !2 Change preprocess
@@ -12,7 +12,7 @@ max_atoms_in_one_type=510
dim_encoder = 768
[MODEL_DECODER]
nhead=8
nhead=4
num_layers=1
dropout=0.1
dim_feedforward=512
@@ -26,7 +26,7 @@ dropout=0.1
sinkhorn_iters=3
[MODEL_TRAINING]
batch_size=16
batch_size=32
epoch=30
seed_val=42
learning_rate=2e-4
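These hunks halve the decoder head count (8 to 4) and double the training batch size (16 to 32). The project reads these values through its Configuration module, which is not shown in this diff; below is a minimal sketch assuming a standard INI file parsed with Python's configparser (the file path is hypothetical).

# Minimal sketch: reading the tuned hyperparameters from an INI-style config.
# Assumes a configparser-compatible file; the project's Configuration wrapper may differ.
from configparser import ConfigParser

config = ConfigParser()
config.read("Configuration/config.ini")  # hypothetical path

nhead = config.getint("MODEL_DECODER", "nhead")              # 4 after this change
batch_size = config.getint("MODEL_TRAINING", "batch_size")   # 32 after this change
learning_rate = config.getfloat("MODEL_TRAINING", "learning_rate")
print(nhead, batch_size, learning_rate)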
import torch
from torch.nn import Module, Embedding
class AtomEmbedding(Module):
def __init__(self, dim_linker, atom_vocab_size, padding_idx=None):
super(AtomEmbedding, self).__init__()
self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_linker, padding_idx=padding_idx,
scale_grad_by_freq=True)
def forward(self, x):
return self.emb(x)
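For reference, a small usage sketch of this wrapper (later in the diff, Linker replaces it with a plain torch.nn.Embedding configured the same way, which is functionally equivalent). The sizes below are toy values; [PAD] = 16 matches atom_map.

import torch
from Linker.AtomEmbedding import AtomEmbedding

# Toy vocabulary of 17 atom ids, with [PAD] = 16 as in atom_map.
emb = AtomEmbedding(dim_linker=8, atom_vocab_size=17, padding_idx=16)

# A batch of 2 sentences, each padded to 5 atom tokens.
atom_ids = torch.tensor([[2, 3, 2, 16, 16], [6, 2, 16, 16, 16]])
out = emb(atom_ids)
print(out.shape)  # torch.Size([2, 5, 8]); the padding row of the weight matrix stays zero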
@@ -6,22 +6,20 @@ import datetime
import time
import torch.nn.functional as F
from torch.nn import Sequential, LayerNorm, Dropout
from torch.nn import Sequential, LayerNorm, Dropout, Embedding
from torch.optim import AdamW
from torch.utils.data import TensorDataset, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from Configuration import Configuration
from Linker.AtomEmbedding import AtomEmbedding
from Linker.AtomTokenizer import AtomTokenizer
from Linker.MHA import AttentionDecoderLayer
from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from Linker.atom_map import atom_map
from Linker.atom_map import atom_map, atom_map_redux
from Linker.eval import mesure_accuracy, SinkhornLoss
from Linker.utils_linker import FFN, get_axiom_links, get_GOAL
from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx
from Supertagger import *
from utils import pad_sequence
def format_time(elapsed):
@@ -62,18 +60,20 @@ class Linker(Module):
atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
self.dropout = Dropout(0.1)
self.device = "cpu"
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
supertagger = SuperTagger()
supertagger.load_weights(supertagger_path_model)
self.Supertagger = supertagger
self.atom_map = atom_map
self.sub_atoms_type_list = ['cl_r', 'pp', 'n', 'np', 'cl_y', 'txt', 's']
self.sub_atoms_type_list = list(atom_map_redux.keys())
self.padding_id = self.atom_map['[PAD]']
self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
self.inverse_map = self.atoms_tokenizer.inverse_atom_map
self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, atom_vocab_size, self.padding_id)
self.atoms_embedding = Embedding(num_embeddings=atom_vocab_size, embedding_dim=self.dim_embedding_atoms,
padding_idx=self.padding_id,
scale_grad_by_freq=True)
self.linker_encoder = AttentionDecoderLayer()
@@ -90,8 +90,6 @@ class Linker(Module):
self.optimizer = AdamW(self.parameters(),
lr=learning_rate)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to(self.device)
def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
@@ -110,12 +108,15 @@ class Linker(Module):
atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
pos_idx = get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
neg_idx = get_neg_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, self.sub_atoms_type_list, atoms_polarity_batch,
df_axiom_links["Y"])
truth_links_batch = truth_links_batch.permute(1, 0, 2)
# Build the tensor dataset
dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch, sentences_tokens,
dataset = TensorDataset(atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch, sentences_tokens,
sentences_mask)
if validation_rate > 0.0:
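The preprocessed dataset therefore no longer stores a flat polarity mask; it stores, for each of the seven reduced atom types, the positions of the positive and negative atoms. A shape-level sketch of the six tensors that enter the TensorDataset (the batch and sentence sizes are illustrative; only max_atoms_in_one_type = 510 comes from the configuration above):

import torch
from torch.utils.data import TensorDataset

batch, max_atoms_in_sentence, max_len_sentence = 32, 256, 128  # illustrative sizes
n_types, half = 7, 510 // 2  # len(atom_map_redux) and max_atoms_in_one_type // 2

atoms_batch_tokenized = torch.zeros(batch, max_atoms_in_sentence, dtype=torch.long)
pos_idx = torch.full((batch, n_types, half), -1, dtype=torch.long)      # positions of positive atoms, -1 padded
neg_idx = torch.full((batch, n_types, half), -1, dtype=torch.long)      # positions of negative atoms, -1 padded
truth_links = torch.full((batch, n_types, half), -1, dtype=torch.long)  # index of the linked negative atom
sentences_tokens = torch.zeros(batch, max_len_sentence, dtype=torch.long)
sentences_mask = torch.ones(batch, max_len_sentence, dtype=torch.long)

dataset = TensorDataset(atoms_batch_tokenized, pos_idx, neg_idx, truth_links, sentences_tokens, sentences_mask)
print(len(dataset), [t.shape for t in dataset[0]])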
@@ -136,11 +137,12 @@ class Linker(Module):
decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
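For orientation, the mask built here zeroes the padded atom positions and is then broadcast to the (batch_size * nhead, max_atoms_in_sentence, max_atoms_in_sentence) layout expected by multi-head attention. A shape-level sketch, assuming the collapsed first lines of make_decoder_mask initialise the mask to ones:

import torch

nhead, padding_id = 4, 16
atoms_token = torch.tensor([[2, 3, 16, 16]])  # one sentence with 2 real atoms and 2 [PAD]
decoder_attn_mask = torch.ones(atoms_token.shape, dtype=torch.float)
decoder_attn_mask[atoms_token.eq(padding_id)] = 0.0
decoder_attn_mask = decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(nhead, 1, 1)
print(decoder_attn_mask.shape)  # torch.Size([4, 4, 4]) = (batch_size * nhead, n_atoms, n_atoms)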
def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
def forward(self, atoms_batch_tokenized, batch_pos_idx, batch_neg_idx, sents_embedding, sents_mask=None):
r"""
Args:
atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
batch_pos_idx : (batch_size, atom_vocab_size, max_atoms_in_one_cat) positions of the positive atoms of each reduced atom type, padded with -1
batch_neg_idx : (batch_size, atom_vocab_size, max_atoms_in_one_cat) positions of the negative atoms of each reduced atom type, padded with -1
sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
sents_mask : mask from BERT tokenizer
Returns:
@@ -157,13 +159,8 @@ class Linker(Module):
link_weights = []
for atom_type in self.sub_atoms_type_list:
pos_encoding = torch.stack([self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
atoms_polarity_batch, atom_type, s_idx)
for s_idx in range(len(atoms_polarity_batch))])
neg_encoding = torch.stack([self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
atoms_polarity_batch, atom_type, s_idx)
for s_idx in range(len(atoms_polarity_batch))])
pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type)
neg_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_neg_idx, atom_type)
pos_encoding = self.pos_transformation(pos_encoding)
neg_encoding = self.neg_transformation(neg_encoding)
@@ -172,9 +169,8 @@ class Linker(Module):
link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
total_link_weights = torch.stack(link_weights)
link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)
return F.log_softmax(link_weights_per_batch, dim=3)
return F.log_softmax(total_link_weights, dim=3)
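For each reduced atom type, the link weights are a bilinear score between the transformed positive and negative atom encodings, which Sinkhorn iterations turn into (the log of) an approximately doubly stochastic matrix. The project's sinkhorn_fn_no_exp is not shown in this diff; below is a minimal log-domain sketch of what such a normalisation typically looks like.

import torch

def log_sinkhorn(scores, iters=3):
    # scores: (batch, n, n) raw link weights, e.g. pos_encoding @ neg_encoding.transpose(2, 1)
    log_p = scores
    for _ in range(iters):
        log_p = log_p - torch.logsumexp(log_p, dim=2, keepdim=True)  # normalise rows
        log_p = log_p - torch.logsumexp(log_p, dim=1, keepdim=True)  # normalise columns
    return log_p

weights = torch.randn(2, 4, 4)  # one atom type, batch of 2, 4 candidate atoms per polarity
links = log_sinkhorn(weights, iters=3)
print(links.exp().sum(dim=1))  # columns sum to 1 exactly; rows are approximately 1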
def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
batch_size=32, checkpoint=True, tensorboard=False):
@@ -246,10 +242,11 @@ class Linker(Module):
for batch in tepoch:
# Unpack this training batch from our dataloader
batch_atoms = batch[0].to(self.device)
batch_polarity = batch[1].to(self.device)
batch_true_links = batch[2].to(self.device)
batch_sentences_tokens = batch[3].to(self.device)
batch_sentences_mask = batch[4].to(self.device)
batch_pos_idx = batch[1].to(self.device)
batch_neg_idx = batch[2].to(self.device)
batch_true_links = batch[3].to(self.device)
batch_sentences_tokens = batch[4].to(self.device)
batch_sentences_mask = batch[5].to(self.device)
self.optimizer.zero_grad()
@@ -257,7 +254,8 @@ class Linker(Module):
logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
# Run the linker on the predicted categories
logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
logits_predictions = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding,
batch_sentences_mask)
linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
# Perform a backward pass to calculate the gradients.
@@ -280,67 +278,24 @@ class Linker(Module):
return avg_train_loss, avg_accuracy_train, training_time
def predict(self, categories, sents_embedding, sents_mask=None):
r"""Prediction from categories output by BERT and hidden_state from BERT
Args:
categories : (batch_size, len_sentence)
sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
sents_mask
Returns:
axiom_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) index of the linked negative atom for each positive atom
"""
self.eval()
with torch.no_grad():
# get atoms
atoms_batch, polarities = get_GOAL(self.max_atoms_in_sentence, categories)
atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
# atoms embedding
atoms_embedding = self.atoms_embedding(atoms_tokenized)
# MHA or LSTM over the BERT output
atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
self.make_decoder_mask(atoms_tokenized))
link_weights = []
for atom_type in self.sub_atoms_type_list:
pos_encoding = pad_sequence(
[self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
for s_idx in range(len(polarities))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2)
neg_encoding = pad_sequence(
[self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
for s_idx in range(len(polarities))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2)
pos_encoding = self.pos_transformation(pos_encoding)
neg_encoding = self.neg_transformation(neg_encoding)
weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
link_weights.append(sinkhorn(weights, iters=3))
logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
return axiom_links
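Reading the output: with the shapes above, axiom_links[b, t, i] is the index of the negative atom of reduced type t predicted to link with the i-th positive atom of that type in sentence b. An illustrative snippet with a dummy tensor standing in for the model output:

import torch
from Linker.atom_map import atom_map_redux

# Dummy stand-in with the same layout as predict()'s output:
# (batch_size, len(atom_map_redux), max_atoms_in_one_type // 2).
axiom_links = torch.randint(0, 255, (2, 7, 255))
for atom_type, t in atom_map_redux.items():
    print(atom_type, axiom_links[0, t, :5].tolist())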
def eval_batch(self, batch):
batch_atoms = batch[0].to(self.device)
batch_polarity = batch[1].to(self.device)
batch_true_links = batch[2].to(self.device)
batch_sentences_tokens = batch[3].to(self.device)
batch_sentences_mask = batch[4].to(self.device)
batch_pos_idx = batch[1].to(self.device)
batch_neg_idx = batch[2].to(self.device)
batch_true_links = batch[3].to(self.device)
batch_sentences_tokens = batch[4].to(self.device)
batch_sentences_mask = batch[5].to(self.device)
logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding,
logits_axiom_links_pred = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding,
batch_sentences_mask)
axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)
print('\n')
print("Sentence tokens: ", batch_sentences_tokens[1])
print("Atoms in the sentence: ", (batch_atoms[1][:50]))
print("Atom polarities in the sentence: ", batch_polarity[1][:50])
print("Positions of the positive atoms in the sentence: ", batch_pos_idx[1][:50])
print("Positions of the negative atoms in the sentence: ", batch_neg_idx[1][:50])
print("True links for category n: ", batch_true_links[1][2][:100])
print("Predictions for category n: ", axiom_links_pred[2][1][:100])
print('\n')
@@ -402,34 +357,18 @@ class Linker(Module):
}, path)
self.to(self.device)
def get_pos_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
bool(re.match(r"" + atom_type + "_?\w*",
self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
atoms_polarity_batch[s_idx][i])]
if len(pos_encoding) == 0:
return torch.zeros(self.max_atoms_in_one_type//2, self.dim_embedding_atoms,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
else:
len_pos_encoding = len(pos_encoding)
pos_encoding += [torch.zeros(self.dim_embedding_atoms,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) for i in
range(self.max_atoms_in_one_type//2 - len_pos_encoding)]
return torch.stack(pos_encoding)
def get_neg_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
bool(re.match(r"" + atom_type + "_?\w*",
self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
not atoms_polarity_batch[s_idx][i])]
if len(neg_encoding) == 0:
return torch.zeros(self.max_atoms_in_one_type//2, self.dim_embedding_atoms,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
else:
len_neg_encoding = len(neg_encoding)
neg_encoding += [torch.zeros(self.dim_embedding_atoms,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) for i in
range(self.max_atoms_in_one_type//2 - len_neg_encoding)]
return torch.stack(neg_encoding)
def make_sinkhorn_inputs(self, bsd_tensor, positional_ids, atom_type):
"""
:param bsd_tensor:
Tensor of shape batch size \times sequence length \times feature dimensionality.
:param positional_ids:
A List of batch_size elements, each being a List of num_atoms LongTensors.
Each LongTensor in positional_ids[b][a] indexes the location of atoms of type a in sentence b.
:param atom_type:
:return:
"""
return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0, index=int(atom)).to(self.device)
if atom != -1 else torch.zeros(self.dim_embedding_atoms, device=self.device)
for atom in sentence])
for i, sentence in enumerate(positional_ids[:, atom_map_redux[atom_type], :])])
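A standalone restatement of this gather, for illustration only: slots whose positional id is -1 select a zero vector, so padded positions contribute nothing to the downstream bilinear scores.

import torch

def gather_atom_encodings(bsd_tensor, type_positions, dim):
    # bsd_tensor: (batch, seq_len, dim) atom encodings; type_positions: (batch, slots) with -1 padding.
    out = []
    for b, slots in enumerate(type_positions):
        rows = [bsd_tensor[b, int(p)] if p != -1 else torch.zeros(dim) for p in slots]
        out.append(torch.stack(rows))
    return torch.stack(out)  # (batch, slots, dim)

encodings = torch.randn(2, 5, 4)             # batch = 2, 5 atom positions, dim = 4
positions = torch.tensor([[0, 3], [2, -1]])  # slot indices per sentence, -1 = padding
print(gather_atom_encodings(encodings, positions, 4).shape)  # torch.Size([2, 2, 4])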
from .Linker import Linker
from .atom_map import atom_map
from .AtomEmbedding import AtomEmbedding
from .AtomTokenizer import AtomTokenizer
\ No newline at end of file
@@ -17,3 +17,13 @@ atom_map = \
's_ppart': 15,
'[PAD]': 16
}
atom_map_redux = {
'cl_r': 0,
'pp': 1,
'n': 2,
'np': 3,
'cl_y': 4,
'txt': 5,
's': 6
}
@@ -9,14 +9,15 @@ class SinkhornLoss(Module):
def forward(self, predictions, truths):
return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
for link, perm in zip(predictions, truths))
for link, perm in zip(predictions, truths.permute(1, 0, 2)))
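The loss is a per-type negative log-likelihood over the Sinkhorn-normalised link weights, with -1 marking padded slots. A minimal self-contained example of one such term (toy sizes):

import torch
from torch.nn.functional import nll_loss

link = torch.log_softmax(torch.randn(2, 4, 4), dim=2)  # one atom type: (batch, n_pos, n_neg) log link weights
perm = torch.tensor([[1, 0, 3, -1], [2, -1, -1, -1]])  # true negative-atom index per positive atom, -1 = padding
loss = nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
print(loss)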
def mesure_accuracy(batch_true_links, axiom_links_pred):
r"""
batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
axiom_links_pred : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
batch_true_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
"""
batch_true_links=batch_true_links.permute(1, 0, 2)
correct_links = torch.ones(axiom_links_pred.size())
correct_links[axiom_links_pred != batch_true_links] = 0
correct_links[batch_true_links == -1] = 1
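The remaining lines of mesure_accuracy are collapsed in this view. As an illustration of the same idea (not the hidden implementation), a masked accuracy can simply exclude the -1 padding slots from both the numerator and the denominator:

import torch

def masked_link_accuracy(axiom_links_pred, batch_true_links):
    # Both tensors: (atom_vocab_size, batch_size, max_atoms_in_one_cat); -1 marks padding.
    mask = batch_true_links != -1
    correct = (axiom_links_pred == batch_true_links) & mask
    return correct.sum().item() / max(mask.sum().item(), 1)

pred = torch.tensor([[[1, 0, 2]]])
true = torch.tensor([[[1, 2, -1]]])
print(masked_link_accuracy(pred, true))  # 0.5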
@@ -3,7 +3,7 @@ import regex
import torch
from torch.nn import Sequential, Linear, Dropout, GELU
from torch.nn import Module
from Linker.atom_map import atom_map
from Linker.atom_map import atom_map, atom_map_redux
from utils import pad_sequence
@@ -276,4 +276,19 @@ def get_GOAL(max_atoms_in_sentence, categories_batch):
################################ Prepare encoding ###############################################
#########################################################################################
def get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
inverse_atom_map = {v: k for k, v in atom_map.items()}
pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(re.match(r"" + atom_type + "_?\w*", inverse_atom_map[int(atoms_batch_tokenized[s_idx][i])])) and
atoms_polarity_batch[s_idx][i]]) for s_idx, sentence in enumerate(atoms_batch_tokenized)], max_len=max_atoms_in_one_type//2, padding_value=-1)
for atom_type in list(atom_map_redux.keys())]
return torch.stack(pos_idx).permute(1, 0, 2)
def get_neg_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
inverse_atom_map = {v: k for k, v in atom_map.items()}
neg_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(re.match(r"" + atom_type + "_?\w*", inverse_atom_map[int(atoms_batch_tokenized[s_idx][i])])) and
not atoms_polarity_batch[s_idx][i]]) for s_idx, sentence in enumerate(atoms_batch_tokenized)], max_len=max_atoms_in_one_type//2, padding_value=-1)
for atom_type in list(atom_map_redux.keys())]
return torch.stack(neg_idx).permute(1, 0, 2)
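These two helpers produce the index tensors stored in the TensorDataset above. A usage sketch following the order of operations in Linker.__preprocess_data; the sentence sizes are illustrative and load_category_sequences is a hypothetical placeholder for the category column of the dataset.

import torch
from Linker.AtomTokenizer import AtomTokenizer
from Linker.atom_map import atom_map
from Linker.utils_linker import get_GOAL, get_pos_idx, get_neg_idx

max_atoms_in_sentence, max_atoms_in_one_type = 256, 510  # illustrative / from the config

categories_batch = load_category_sequences()  # hypothetical helper: one category sequence per sentence
atoms_batch, polarities = get_GOAL(max_atoms_in_sentence, categories_batch)

tokenizer = AtomTokenizer(atom_map, max_atoms_in_sentence)
atoms_tokenized = tokenizer.convert_batchs_to_ids(atoms_batch)

pos_idx = get_pos_idx(atoms_tokenized, polarities, max_atoms_in_one_type)
neg_idx = get_neg_idx(atoms_tokenized, polarities, max_atoms_in_one_type)
print(pos_idx.shape, neg_idx.shape)  # both (batch_size, len(atom_map_redux), max_atoms_in_one_type // 2), -1 padded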