Commit 8d109c5a authored by Caroline DE POURTALES

added supertagger, running is ok, need to decide on MHA

parent 4e0e61ea
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
@@ -4,7 +4,7 @@ transformers = 4.16.2
 [DATASET_PARAMS]
 symbols_vocab_size=26
 atom_vocab_size=20
-max_len_sentence=148
+max_len_sentence=109
 max_atoms_in_sentence=1250
 max_atoms_in_one_type=250
...
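These values are read elsewhere in the diff through Configuration.datasetConfig[...]. A minimal sketch of reading such a [DATASET_PARAMS] section with Python's configparser; the real Configuration module may load it differently, and the inline string simply copies the values shown above so the example runs on its own.

# Sketch only: the project's Configuration module may differ.
from configparser import ConfigParser

parser = ConfigParser()
parser.read_string("""
[DATASET_PARAMS]
symbols_vocab_size=26
atom_vocab_size=20
max_len_sentence=109
max_atoms_in_sentence=1250
max_atoms_in_one_type=250
""")
datasetConfig = parser["DATASET_PARAMS"]
print(int(datasetConfig["max_len_sentence"]))       # 109 after this change (was 148)
print(int(datasetConfig["max_atoms_in_one_type"]))  # 250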
 import torch
-from ..utils import pad_sequence
+from utils import pad_sequence
 class AtomTokenizer(object):
...
@@ -12,18 +12,18 @@ from torch.utils.data import TensorDataset, random_split
 from transformers import get_cosine_schedule_with_warmup
 from Configuration import Configuration
-from AtomEmbedding import AtomEmbedding
-from AtomTokenizer import AtomTokenizer
-from MHA import AttentionDecoderLayer
-from atom_map import atom_map
-from Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links
-from eval import mesure_accuracy, SinkhornLoss
-from ..utils import pad_sequence
+from Linker.AtomEmbedding import AtomEmbedding
+from Linker.AtomTokenizer import AtomTokenizer
+from Linker.MHA import AttentionDecoderLayer
+from Linker.atom_map import atom_map
+from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links
+from Linker.eval import mesure_accuracy, SinkhornLoss
+from utils import pad_sequence
 class Linker(Module):
-    def __init__(self):
+    def __init__(self, supertagger):
         super(Linker, self).__init__()
         self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
@@ -39,6 +39,8 @@ class Linker(Module):
         self.dropout = Dropout(0.1)
         self.device = ""
+        self.Supertagger = supertagger
         self.atom_map = atom_map
         self.padding_id = self.atom_map['[PAD]']
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
@@ -63,7 +65,7 @@ class Linker(Module):
                                                          num_warmup_steps=0,
                                                          num_training_steps=100)
-    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.0):
+    def __preprocess_data(self, batch_size, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.0):
         atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
         atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
@@ -75,7 +77,8 @@ class Linker(Module):
         truth_links_batch = truth_links_batch.permute(1, 0, 2)
         # Construction tensor dataset
-        dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch)
+        dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch, sentences_tokens,
+                                sentences_mask)
         if validation_rate > 0:
             train_size = int(0.9 * len(dataset))
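A minimal sketch (synthetic shapes, not the project's real tensors) of why batches now unpack five elements: TensorDataset zips its tensors row-wise, so batch[3] and batch[4] in the training loop below carry the sentence tokens and mask added to the dataset here.

# Sketch with placeholder shapes; only the 5-tuple structure matters.
import torch
from torch.utils.data import TensorDataset, DataLoader

n, max_atoms, max_len = 8, 20, 10
atoms = torch.randint(0, 5, (n, max_atoms))
polarities = torch.zeros(n, max_atoms, dtype=torch.bool)
links = torch.zeros(n, 3, 4, dtype=torch.long)            # placeholder truth links
sentences_tokens = torch.randint(0, 100, (n, max_len))
sentences_mask = torch.ones(n, max_len, dtype=torch.long)

loader = DataLoader(TensorDataset(atoms, polarities, links, sentences_tokens, sentences_mask),
                    batch_size=4)
batch = next(iter(loader))
print(len(batch))  # 5 -> batch[3] is sentences_tokens, batch[4] is sentences_mask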
@@ -109,8 +112,7 @@ class Linker(Module):
         atoms_embedding = self.atoms_embedding(atoms_batch_tokenized)
         # MHA ou LSTM avec sortie de BERT
-        sents_embedding = torch.randn(32, self.max_len_sentence, self.dim_encoder)
-        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
+        batch_size, _, _ = sents_embedding.shape
         sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence)
         atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
                                              self.make_decoder_mask(atoms_batch_tokenized))
@@ -143,12 +145,15 @@ class Linker(Module):
         return torch.stack(link_weights)
-    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32, checkpoint=True, validate=True):
-        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links, validation_rate)
+    def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20,
+                     batch_size=32, checkpoint=True, validate=True):
+        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
+                                                                            sentences_tokens, sentences_mask,
+                                                                            validation_rate)
         for epoch_i in range(0, epochs):
-            epoch_acc, epoch_loss = self.train_epoch(training_dataloader, validation_dataloader)
+            epoch_acc, epoch_loss = self.train_epoch(training_dataloader, validation_dataloader, checkpoint, validate)
     def train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True):
@@ -163,15 +168,16 @@ class Linker(Module):
            batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
            batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
            batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-            # batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
+            batch_sentences_tokens = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
+            batch_sentences_mask = batch[4].to("cuda" if torch.cuda.is_available() else "cpu")
            self.optimizer.zero_grad()
            # get sentence embedding from BERT which is already trained
-            # sentences_embedding = supertagger(batch_sentences)
+            logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask)
            # Run the kinker on the categories predictions
-            logits_predictions = self(batch_atoms, batch_polarity, [])
+            logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
            linker_loss = self.cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links)
            # Perform a backward pass to calculate the gradients.
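The loss above is built on sinkhorn_fn_no_exp, imported at the top of Linker.py but not shown in this diff. The sketch below is only an illustration of the generic log-domain Sinkhorn normalisation such a function typically performs, turning raw link scores into an approximately doubly-stochastic matching matrix; it is not the project's implementation.

# Illustration only, not the code behind sinkhorn_fn_no_exp.
import torch

def sinkhorn_log(scores, n_iters=3):
    """scores: (batch, n, n) raw compatibilities; returns the log of a soft permutation."""
    log_p = scores
    for _ in range(n_iters):
        log_p = log_p - torch.logsumexp(log_p, dim=2, keepdim=True)  # normalise rows
        log_p = log_p - torch.logsumexp(log_p, dim=1, keepdim=True)  # normalise columns
    return log_p

log_links = sinkhorn_log(torch.randn(2, 5, 5))
print(log_links.exp().sum(dim=1))  # columns sum to 1; rows are approximately normalised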
@@ -256,9 +262,11 @@
                batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
                batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
                batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-                # batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
-                logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, [])
+                batch_sentences_tokens = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
+                batch_sentences_mask = batch[4].to("cuda" if torch.cuda.is_available() else "cpu")
+                logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, batch_sentences_tokens,
+                                                       batch_sentences_mask)
                logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
                axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
...
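The evaluation path above takes an argmax over the softmaxed link logits. A hedged sketch of scoring such predictions against the gold links while ignoring padding; mesure_accuracy in Linker.eval may compute this differently.

# Sketch only; pad_id and shapes are illustrative.
import torch

def link_accuracy(pred_links, true_links, pad_id):
    mask = true_links != pad_id                  # ignore padded positions
    correct = (pred_links == true_links) & mask
    return correct.sum().item() / max(mask.sum().item(), 1)

pred = torch.tensor([[0, 2, 1, 3]])
gold = torch.tensor([[0, 2, 2, 3]])
print(link_accuracy(pred, gold, pad_id=3))  # 2 correct out of 3 non-padding positions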
@@ -2,7 +2,7 @@ from torch import Tensor
 from torch.nn import (Dropout, LayerNorm, Module, MultiheadAttention)
 from Configuration import Configuration
-from utils_linker import FFN
+from Linker.utils_linker import FFN
 class AttentionDecoderLayer(Module):
@@ -35,8 +35,6 @@ class AttentionDecoderLayer(Module):
         # init params
         dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
         dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
-        max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
-        atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
         nhead = int(Configuration.modelLinkerConfig['nhead'])
         dropout = float(Configuration.modelLinkerConfig['dropout'])
         dim_feedforward = int(Configuration.modelLinkerConfig['dim_feedforward'])
...
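The commit message notes that the choice of MHA is still open. For context, a minimal sketch of the cross-attention a layer like AttentionDecoderLayer would perform with torch.nn.MultiheadAttention, using atom embeddings as queries and the supertagger's sentence embeddings as keys/values; the dimensions are placeholders, not the project's configuration.

# Sketch only; dim and nhead are placeholders.
import torch
from torch.nn import MultiheadAttention

dim, nhead = 64, 4
batch, n_atoms, len_sentence = 2, 12, 10

mha = MultiheadAttention(dim, nhead, batch_first=True)
atoms_embedding = torch.randn(batch, n_atoms, dim)        # queries (one per atom)
sents_embedding = torch.randn(batch, len_sentence, dim)   # keys/values (BERT output)

out, attn_weights = mha(atoms_embedding, sents_embedding, sents_embedding)
print(out.shape)           # torch.Size([2, 12, 64]): one contextualised vector per atom
print(attn_weights.shape)  # torch.Size([2, 12, 10]): attention over sentence positions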
@@ -3,8 +3,8 @@ import regex
 import torch
 from torch.nn import Sequential, Linear, Dropout, GELU
 from torch.nn import Module
-from atom_map import atom_map
-from ..utils import pad_sequence
+from Linker.atom_map import atom_map
+from utils import pad_sequence
 class FFN(Module):
...
@@ -3,4 +3,5 @@ transformers==4.16.2
 torch==1.10.2
 huggingface-hub==0.4.0
 pandas==1.4.1
-sentencepiece
\ No newline at end of file
+sentencepiece
+git+https://gitlab.irit.fr/pnria/global-helper/deepgrail-rnn/
\ No newline at end of file
-import os
-import time
-from datetime import datetime
-import numpy as np
 import torch
-from torch.optim import AdamW
-from torch.utils.data import Dataset, TensorDataset, random_split
-from transformers import (get_cosine_schedule_with_warmup)
 from Configuration import Configuration
-from Linker.AtomTokenizer import AtomTokenizer
 from Linker.Linker import Linker
-from Linker.atom_map import atom_map
-from Linker.utils_linker import get_axiom_links, get_atoms_batch, find_pos_neg_idexes
-from Linker.eval import SinkhornLoss
-from utils import format_time, read_csv_pgbar
+from Supertagger.SuperTagger.SuperTagger import SuperTagger
+from utils import read_csv_pgbar
 torch.cuda.empty_cache()
-# region ParamsModel
-max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
-max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
-max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
-atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
-# endregion ParamsModel
-# region ParamsTraining
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
 nb_sentences = batch_size * 10
 epochs = int(Configuration.modelTrainingConfig['epoch'])
-seed_val = int(Configuration.modelTrainingConfig['seed_val'])
-learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
-# endregion ParamsTraining
-# region Data loader
 file_path_axiom_links = 'Datasets/aa1_links_dataset_links.csv'
 df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
-sentences_batch = df_axiom_links["Sentences"]
+sentences_batch = df_axiom_links["Sentences"].tolist()
-atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
-atom_tokenizer = AtomTokenizer(atom_map, max_atoms_in_sentence)
-atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
-print("atoms_tokens", atoms_batch_tokenized.shape)
-atoms_polarity_batch = find_pos_neg_idexes(max_atoms_in_sentence, df_axiom_links["sub_tree"])
-print("atoms_polarity_batch", atoms_polarity_batch.shape)
-torch.set_printoptions(edgeitems=20)
-truth_links_batch = get_axiom_links(max_atoms_in_one_type, atoms_polarity_batch, df_axiom_links["sub_tree"])
-truth_links_batch = truth_links_batch.permute(1, 0, 2)
-print("truth_links_batch", truth_links_batch.shape)
-print("sentence", sentences_batch[14])
-print("categories ", df_axiom_links["sub_tree"][14])
-print("atoms_batch", atoms_batch[14])
-print("atoms_polarity_batch", atoms_polarity_batch[14])
-print(" truth_links_batch example on a sentence class txt", truth_links_batch[14][16])
-# Construction tensor dataset
-dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch)
-# Calculate the number of samples to include in each set.
-train_size = int(0.9 * len(dataset))
-val_size = len(dataset) - train_size
-# Divide the dataset by randomly selecting samples.
-train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
-validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
-# endregion Data loader
-# region Models
-# supertagger = SuperTagger()
-# supertagger.load_weights("models/model_check.pt")
-linker = Linker()
-# endregion Models
-# region Fit tunning
-# Optimizer
-optimizer_linker = AdamW(linker.parameters(),
-                         weight_decay=1e-5,
-                         lr=learning_rate)
-# Create the learning rate scheduler.
-scheduler_linker = get_cosine_schedule_with_warmup(optimizer_linker,
-                                                   num_warmup_steps=0,
-                                                   num_training_steps=100)
-# Loss
-cross_entropy_loss = SinkhornLoss()
-np.random.seed(seed_val)
-torch.manual_seed(seed_val)
-torch.cuda.manual_seed_all(seed_val)
-torch.autograd.set_detect_anomaly(True)
-# endregion Fit tunning
-# region Train
-# Measure the total training time for the whole run.
-total_t0 = time.time()
-validate = True
-checkpoint = True
-def run_epochs(epochs):
-    # For each epoch...
-    for epoch_i in range(0, epochs):
-        # ========================================
-        #               Training
-        # ========================================
-        # Perform one full pass over the training set.
-        print("")
-        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
-        print('Training...')
-        # Measure how long the training epoch takes.
-        t0 = time.time()
-        # Reset the total loss for this epoch.
-        total_train_loss = 0
-        linker.train()
-        # For each batch of training data...
-        for step, batch in enumerate(training_dataloader):
-            # Unpack this training batch from our dataloader
-            batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
-            batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
-            batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-            # batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
-            optimizer_linker.zero_grad()
-            # get sentence embedding from BERT which is already trained
-            # sentences_embedding = supertagger(batch_sentences)
-            # Run the kinker on the categories predictions
-            logits_predictions = linker(batch_atoms, batch_polarity, [])
-            linker_loss = cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links)
-            # Perform a backward pass to calculate the gradients.
-            total_train_loss += float(linker_loss)
-            linker_loss.backward()
-            # This is to help prevent the "exploding gradients" problem.
-            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)
-            # Update parameters and take a step using the computed gradient.
-            optimizer_linker.step()
-            scheduler_linker.step()
-        avg_train_loss = total_train_loss / len(training_dataloader)
-        # Measure how long this epoch took.
-        training_time = format_time(time.time() - t0)
-        if checkpoint:
-            checkpoint_dir = os.path.join("Output", 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M'))
-            linker.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))
-        if validate:
-            linker.eval()
-            with torch.no_grad():
-                print("Start eval")
-                accuracy, loss = linker.eval_epoch(validation_dataloader, cross_entropy_loss)
-                print("")
-                print(" Average accuracy on epoch: {0:.2f}".format(accuracy))
-                print(" Average loss on epoch: {0:.2f}".format(loss))
-        print("")
-        print(" Average training loss: {0:.2f}".format(avg_train_loss))
-        print(" Training epcoh took: {:}".format(training_time))
-run_epochs(epochs)
-# endregion Train
+supertagger = SuperTagger()
+supertagger.load_weights("models/model_supertagger.pt")
+sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
+print("Linker")
+linker = Linker(supertagger)
+print("Linker Training")
+linker.train_linker(df_axiom_links, sents_tokenized, sents_mask, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, validate=True)
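The new entry point assumes a SuperTagger that exposes load_weights, a sent_tokenizer with fit_transform_tensors, and a foward(tokens, mask) method returning (logits, sentence embeddings), as consumed in Linker.train_epoch. A hypothetical stub spelling out that contract; the method names come from this diff, while all shapes and the 768-dimensional embedding size are illustrative placeholders, not the real deepgrail-rnn implementation.

# Hypothetical stub of the assumed supertagger interface; not the real SuperTagger.
import torch

class FakeSentTokenizer:
    def fit_transform_tensors(self, sentences):
        # Returns (sentences_tokens, sentences_mask), both (batch, max_len_sentence).
        n, max_len = len(sentences), 109  # 109 = max_len_sentence from the config change above
        return (torch.zeros(n, max_len, dtype=torch.long),
                torch.ones(n, max_len, dtype=torch.long))

class FakeSuperTagger:
    def __init__(self):
        self.sent_tokenizer = FakeSentTokenizer()

    def load_weights(self, path):
        pass  # the real SuperTagger restores its pretrained weights here

    def foward(self, tokens, mask):  # spelling matches the call site in Linker.train_epoch
        # Returns (supertag logits, contextual sentence embeddings); 26 = symbols_vocab_size,
        # 768 is an assumed BERT hidden size.
        batch, max_len = tokens.shape
        return torch.zeros(batch, max_len, 26), torch.zeros(batch, max_len, 768)

tagger = FakeSuperTagger()
tokens, mask = tagger.sent_tokenizer.fit_transform_tensors(["une phrase", "une autre phrase"])
logits, embeddings = tagger.foward(tokens, mask)
print(embeddings.shape)  # torch.Size([2, 109, 768])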