Commit fd524d4e authored by Caroline DE POURTALES

adding training methods

parent 154eabc1
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
import os
from datetime import datetime

import torch
from torch.nn import Sequential, LayerNorm, Dropout
from torch.nn import Module
import torch.nn.functional as F
import sys
from torch.optim import AdamW
from torch.utils.data import TensorDataset, random_split
from transformers import get_cosine_schedule_with_warmup

from Configuration import Configuration
from AtomEmbedding import AtomEmbedding
from AtomTokenizer import AtomTokenizer
from MHA import AttentionDecoderLayer
from atom_map import atom_map
from Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links
from eval import mesure_accuracy, SinkhornLoss
from ..utils import pad_sequence

@@ -27,6 +35,10 @@ class Linker(Module):
        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
        self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
        batch_size = int(Configuration.modelTrainingConfig['batch_size'])
        nb_sentences = batch_size * 10
        self.epochs = int(Configuration.modelTrainingConfig['epoch'])
        learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
        self.dropout = Dropout(0.1)
        self.device = ""
@@ -47,6 +59,41 @@ class Linker(Module):
            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
        )
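
        # Objects used by the new training methods below: the SinkhornLoss
        # imported from eval, an AdamW optimiser with weight decay and a
        # cosine learning-rate schedule with warmup.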
        self.cross_entropy_loss = SinkhornLoss()
        self.optimizer = AdamW(self.parameters(),
                               weight_decay=1e-5,
                               lr=learning_rate)
        self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
                                                         num_warmup_steps=0,
                                                         num_training_steps=100)

    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.0):
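        r'''
        Builds the DataLoaders used by train_linker from the axiom-links DataFrame.

        Parameters :

        batch_size : number of sentences per batch
        df_axiom_links : pandas DataFrame with the atoms sub-trees in its 'sub_tree' column
        validation_rate : part of the dataset kept for validation (0 disables the split)

        Returns :

        training_dataloader, validation_dataloader (None when validation_rate is 0)
        '''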
        atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
        atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
        atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)

        atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["sub_tree"])

        torch.set_printoptions(edgeitems=20)
        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
                                            df_axiom_links["sub_tree"])
        truth_links_batch = truth_links_batch.permute(1, 0, 2)

        # Construction of the tensor dataset
        dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch)

        if validation_rate > 0:
            # split according to the requested validation rate
            train_size = int((1 - validation_rate) * len(dataset))
            val_size = len(dataset) - train_size
            train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
            validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        else:
            validation_dataloader = None
            train_dataset = dataset

        training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
        return training_dataloader, validation_dataloader

    def make_decoder_mask(self, atoms_token):
        decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
        decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
@@ -101,6 +148,63 @@ class Linker(Module):
        return torch.stack(link_weights)

    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32):
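        r'''
        Trains the linker on the axiom-links dataset.

        Parameters :

        df_axiom_links : pandas DataFrame with the atoms sub-trees and gold axiom links
        validation_rate : part of the dataset kept for validation
        epochs : number of training epochs requested
        batch_size : number of sentences per batch
        '''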
        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links, validation_rate)
        epochs = epochs - self.epochs
        self.train()

        for epoch_i in range(0, epochs):
            epoch_acc, epoch_loss = self.__train_epoch(training_dataloader, validation_dataloader)

    def __train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True):
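        r'''
        Runs one training epoch, optionally saves a checkpoint and evaluates on the validation set.

        Returns :

        accuracy on the validation set (None when validate is False), average training loss of the epoch
        '''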
        # Reset the total loss for this epoch.
        epoch_loss = 0
        self.train()

        # For each batch of training data...
        for step, batch in enumerate(training_dataloader):
            # Unpack this training batch from our dataloader
            batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
            batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
            batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
            # batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")

            self.optimizer.zero_grad()

            # get sentence embedding from BERT which is already trained
            # sentences_embedding = supertagger(batch_sentences)

            # Run the linker on the categories predictions
            logits_predictions = self(batch_atoms, batch_polarity, [])

            linker_loss = self.cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links)
            # Perform a backward pass to calculate the gradients.
            epoch_loss += float(linker_loss)
            linker_loss.backward()

            # This is to help prevent the "exploding gradients" problem.
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)

            # Update parameters and take a step using the computed gradient.
            self.optimizer.step()
            self.scheduler.step()

        avg_train_loss = epoch_loss / len(training_dataloader)

        if checkpoint:
            checkpoint_dir = os.path.join("Output", 'Training_' + datetime.today().strftime('%d-%m_%H-%M'))
            # make sure the checkpoint directory exists before saving
            os.makedirs(checkpoint_dir, exist_ok=True)
            self.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))

        accuracy = None
        if validate:
            self.eval()
            with torch.no_grad():
                accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)

        return accuracy, avg_train_loss

    def predict(self, categories, sents_embedding, sents_mask=None):
        r'''
        Parameters :
...
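A minimal usage sketch for the new training entry point. This is only an illustration: the Linker constructor arguments and the exact dataset loading are not shown in this commit, so Linker() and the CSV path below are assumptions.

import pandas as pd
from Linker import Linker  # assumed module name for the class above

# Assumed: the CSV written by Txt_to_csv below; the 'sub_tree' column may need
# to be parsed back into Python lists after reading it from disk.
df_axiom_links = pd.read_csv("../Datasets/aa1_links_dataset_links.csv")

linker = Linker()  # constructor arguments are not shown in this commit
linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32)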

@@ -25,7 +25,7 @@ def sub_tree_line(line_with_data: str):
    for word_with_data in line_list:
        w, t = sub_tree_word(word_with_data)
        sentence += ' ' + w
        if t not in ["\\", "/", "let"] and len(t) > 0:
            sub_trees.append([t])
    """if ('ppp' in list(itertools.chain(*sub_trees))):
        print(sentence)"""
@@ -35,17 +35,9 @@ def sub_tree_line(line_with_data: str):


def Txt_to_csv(file_name: str):
    file = open(file_name, "r", encoding="utf8")
    text = file.readlines()
    sub = [sub_tree_line(data) for data in text]
    df = pd.DataFrame(data=sub, columns=['Sentences', 'sub_tree'])
    df.to_csv("../Datasets/" + file_name[:-4] + "_dataset_links.csv", index=False)


Txt_to_csv("aa1_links.txt")
"""trees = df['sub_tree']
trees_flat = set(list(itertools.chain(*list(itertools.chain(*trees)))))
fruit_dictionary = dict(zip(list(trees_flat), range(len(list(trees_flat)))))
print(fruit_dictionary)"""