Commit fd524d4e authored by Caroline DE POURTALES

adding training methods

parent 154eabc1
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
import os
from datetime import datetime
import torch
from torch.nn import Sequential, LayerNorm, Dropout
from torch.nn import Module
import torch.nn.functional as F
import sys
from torch.optim import AdamW
from torch.utils.data import TensorDataset, random_split
from transformers import get_cosine_schedule_with_warmup
from Configuration import Configuration
from AtomEmbedding import AtomEmbedding
from AtomTokenizer import AtomTokenizer
from MHA import AttentionDecoderLayer
from atom_map import atom_map
from Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links
from eval import mesure_accuracy, SinkhornLoss
from ..utils import pad_sequence
@@ -27,6 +35,10 @@ class Linker(Module):
        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
        self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
        batch_size = int(Configuration.modelTrainingConfig['batch_size'])
        nb_sentences = batch_size * 10
        self.epochs = int(Configuration.modelTrainingConfig['epoch'])
        learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
        self.dropout = Dropout(0.1)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -47,6 +59,41 @@ class Linker(Module):
            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
        )

        self.cross_entropy_loss = SinkhornLoss()
        self.optimizer = AdamW(self.parameters(),
                               weight_decay=1e-5,
                               lr=learning_rate)
        self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
                                                         num_warmup_steps=0,
                                                         num_training_steps=100)
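        # NB: num_training_steps is fixed at 100 here; the usual choice for a
        # cosine schedule would be epochs * len(training_dataloader), once the
        # size of the training data is known.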

    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.0):
        atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
        atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
        atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
        atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["sub_tree"])
        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
                                            df_axiom_links["sub_tree"])
        truth_links_batch = truth_links_batch.permute(1, 0, 2)

        # Build the tensor dataset
        dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch)

        if validation_rate > 0:
            train_size = int((1 - validation_rate) * len(dataset))
            val_size = len(dataset) - train_size
            train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
            validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        else:
            validation_dataloader = None
            train_dataset = dataset

        training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
        return training_dataloader, validation_dataloader
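
    # Each training batch then unpacks as (atoms_tokenized, atoms_polarities,
    # truth_links), in the order the TensorDataset above was built.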

    def make_decoder_mask(self, atoms_token):
        # 1.0 everywhere except padding positions, which must not be attended to
        decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
        decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
@@ -101,6 +148,63 @@ class Linker(Module):
        return torch.stack(link_weights)

    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32):
        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                            validation_rate)
        self.train()
        for epoch_i in range(epochs):
            epoch_acc, epoch_loss = self.__train_epoch(training_dataloader, validation_dataloader)
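
    # A minimal usage sketch (hypothetical file name; assumes a DataFrame with
    # the "sub_tree" column that __preprocess_data expects, e.g. a CSV written
    # by Txt_to_csv):
    #
    #   df_axiom_links = pd.read_csv("../Datasets/aa1_links_dataset_links.csv")
    #   linker = Linker()  # constructor arguments elided in this diff
    #   linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32)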

    def __train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True):
        # Reset the total loss for this epoch.
        epoch_loss = 0
        accuracy = None
        self.train()

        # For each batch of training data...
        for step, batch in enumerate(training_dataloader):
            # Unpack this training batch from our dataloader
            batch_atoms = batch[0].to(self.device)
            batch_polarity = batch[1].to(self.device)
            batch_true_links = batch[2].to(self.device)
            # batch_sentences = batch[3].to(self.device)

            self.optimizer.zero_grad()

            # get sentence embedding from BERT which is already trained
            # sentences_embedding = supertagger(batch_sentences)

            # Run the linker on the category predictions
            logits_predictions = self(batch_atoms, batch_polarity, [])

            linker_loss = self.cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links)
            epoch_loss += float(linker_loss)

            # Perform a backward pass to calculate the gradients.
            linker_loss.backward()

            # This is to help prevent the "exploding gradients" problem.
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)

            # Update parameters and take a step using the computed gradient.
            self.optimizer.step()
            self.scheduler.step()

        avg_train_loss = epoch_loss / len(training_dataloader)

        if checkpoint:
            checkpoint_dir = os.path.join("Output", 'Training_' + datetime.today().strftime('%d-%m_%H-%M'))
            os.makedirs(checkpoint_dir, exist_ok=True)
            self.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))

        if validate:
            self.eval()
            with torch.no_grad():
                accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)

        return accuracy, avg_train_loss
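
    # For reference: a minimal sketch of the log-space Sinkhorn normalisation
    # idea behind the imported sinkhorn / SinkhornLoss (iteratively normalising
    # rows and columns of a score matrix toward a doubly stochastic link
    # matrix). This is an illustrative assumption, not the project's actual
    # implementation:
    #
    #   def sinkhorn_sketch(log_alpha, n_iters=3):
    #       for _ in range(n_iters):
    #           log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-1, keepdim=True)
    #           log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-2, keepdim=True)
    #       return log_alpha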

    def predict(self, categories, sents_embedding, sents_mask=None):
        r'''
        Parameters :
......
@@ -25,7 +25,7 @@ def sub_tree_line(line_with_data: str):
    for word_with_data in line_list:
        w, t = sub_tree_word(word_with_data)
        sentence += ' ' + w
        if t not in ["\\", "/", "let"] and len(t) > 0:
            sub_trees.append([t])
"""if ('ppp' in list(itertools.chain(*sub_trees))):
print(sentence)"""
@@ -35,17 +35,9 @@ def sub_tree_line(line_with_data: str):
def Txt_to_csv(file_name: str):
    # read the raw annotated text, one sentence per line
    with open(file_name, "r", encoding="utf8") as file:
        text = file.readlines()
    sub = [sub_tree_line(data) for data in text]
    df = pd.DataFrame(data=sub, columns=['Sentences', 'sub_tree'])
    df.to_csv("../Datasets/" + file_name[:-4] + "_dataset_links.csv", index=False)


Txt_to_csv("aa1_links.txt")
"""trees = df['sub_tree']
trees_flat = set(list(itertools.chain(*list(itertools.chain(*trees)))))
fruit_dictionary = dict(zip(list(trees_flat), range(len(list(trees_flat)))))
print(fruit_dictionary)"""