Commit d891f64e authored by Caroline de Pourtalès

correct code and simplify

parent 5ea1ec72
...@@ -28,7 +28,7 @@ class Linker(Module): ...@@ -28,7 +28,7 @@ class Linker(Module):
# region initialization # region initialization
def __init__(self, supertagger_path_model): def __init__(self):
super(Linker, self).__init__() super(Linker, self).__init__()
# region parameters # region parameters
...@@ -58,12 +58,6 @@ class Linker(Module): ...@@ -58,12 +58,6 @@ class Linker(Module):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# endregion # endregion
# SuperTagger for categories
supertagger = SuperTagger()
supertagger.load_weights(supertagger_path_model)
self.Supertagger = supertagger
self.Supertagger.model.to(self.device)
# Atoms embedding # Atoms embedding
self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence) self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
self.atom_map_redux = atom_map_redux self.atom_map_redux = atom_map_redux
...@@ -118,53 +112,6 @@ class Linker(Module): ...@@ -118,53 +112,6 @@ class Linker(Module):
#endregion #endregion
# region data
def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
r"""
Args:
batch_size : int
df_axiom_links pandas DataFrame
validation_rate
Returns:
the training dataloader and the validation dataloader. They contain the list of atoms, their polarities, the axiom links, the tokenized sentences and the sentence masks
"""
print("Start preprocess Data")
sentences_batch = df_axiom_links["X"].str.strip().tolist()
sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
atoms_batch, polarities, num_atoms_per_word = get_GOAL(self.max_len_sentence, df_axiom_links)
atoms_polarity_batch = pad_sequence(
[torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
max_len=self.max_atoms_in_sentence, padding_value=0)
atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
df_axiom_links["Y"])
truth_links_batch = truth_links_batch.permute(1, 0, 2)
# Construction tensor dataset
dataset = TensorDataset(num_atoms_per_word, atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch,
sentences_tokens, sentences_mask)
if validation_rate > 0.0:
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
else:
validation_dataloader = None
train_dataset = dataset
training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
print("End preprocess Data")
return training_dataloader, validation_dataloader
#endregion
# region training # region training
def make_sinkhorn_inputs(self, bsd_tensor, positional_ids, atom_type): def make_sinkhorn_inputs(self, bsd_tensor, positional_ids, atom_type):
...@@ -229,56 +176,7 @@ class Linker(Module): ...@@ -229,56 +176,7 @@ class Linker(Module):
return F.log_softmax(link_weights, dim=3) return F.log_softmax(link_weights, dim=3)
def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, def train_epoch(self, training_dataloader, Supertagger):
batch_size=32, checkpoint=True, tensorboard=False):
r"""
Args:
df_axiom_links : pandas DataFrame containing the atoms annotated with _i
validation_rate : float
epochs : int
batch_size : int
checkpoint : boolean
tensorboard : boolean
Returns:
Final accuracy and final loss
"""
training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
validation_rate)
if checkpoint or tensorboard:
checkpoint_dir, writer = output_create_dir()
for epoch_i in range(epochs):
print("")
print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
print('Training...')
avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)
print("")
print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
if validation_rate > 0.0:
loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
if checkpoint:
self.__checkpoint_save(
path=os.path.join("Output", 'linker.pt'))
if tensorboard:
writer.add_scalars(f'Accuracy', {
'Train': avg_accuracy_train}, epoch_i)
writer.add_scalars(f'Loss', {
'Train': avg_train_loss}, epoch_i)
if validation_rate > 0.0:
writer.add_scalars(f'Accuracy', {
'Validation': accuracy_test}, epoch_i)
writer.add_scalars(f'Loss', {
'Validation': loss_test}, epoch_i)
print('\n')
def train_epoch(self, training_dataloader):
r""" Train epoch r""" Train epoch
Args: Args:
...@@ -309,12 +207,11 @@ class Linker(Module): ...@@ -309,12 +207,11 @@ class Linker(Module):
self.optimizer.zero_grad() self.optimizer.zero_grad()
# get sentence embedding from BERT which is already trained # get sentence embedding from BERT which is already trained
output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask) output = Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
# Run the Linker on the atoms # Run the Linker on the atoms
logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx,
output['word_embedding']) output['word_embedding'])
linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links) linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
# Perform a backward pass to calculate the gradients. # Perform a backward pass to calculate the gradients.
epoch_loss += float(linker_loss) epoch_loss += float(linker_loss)
...@@ -342,33 +239,7 @@ class Linker(Module): ...@@ -342,33 +239,7 @@ class Linker(Module):
# region evaluation # region evaluation
def eval_batch(self, batch): def eval_epoch(self, dataloader, Supertagger):
batch_num_atoms = batch[0].to(self.device)
batch_atoms_tok = batch[1].to(self.device)
batch_pos_idx = batch[2].to(self.device)
batch_neg_idx = batch[3].to(self.device)
batch_true_links = batch[4].to(self.device)
batch_sentences_tokens = batch[5].to(self.device)
batch_sentences_mask = batch[6].to(self.device)
output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output[
'word_embedding']) # atom_vocab, batch_size, max atoms in one type, max atoms in one type
axiom_links_pred = torch.argmax(logits_predictions, dim=3) # atom_vocab, batch_size, max atoms in one type
print('\n')
print(batch_true_links)
print("Les vrais liens de la catégorie n : ", batch_true_links[0][2][:100])
print("Les prédictions : ", axiom_links_pred[2][0][:100])
print('\n')
accuracy = measure_accuracy(batch_true_links, axiom_links_pred)
loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
return loss, accuracy
def eval_epoch(self, dataloader):
r"""Average the evaluation of all the batch. r"""Average the evaluation of all the batch.
Args: Args:
...@@ -379,107 +250,25 @@ class Linker(Module): ...@@ -379,107 +250,25 @@ class Linker(Module):
loss_average = 0 loss_average = 0
with torch.no_grad(): with torch.no_grad():
for step, batch in enumerate(dataloader): for step, batch in enumerate(dataloader):
loss, accuracy = self.eval_batch(batch) batch_num_atoms = batch[0].to(self.device)
accuracy_average += accuracy batch_atoms_tok = batch[1].to(self.device)
loss_average += float(loss) batch_pos_idx = batch[2].to(self.device)
batch_neg_idx = batch[3].to(self.device)
return loss_average / len(dataloader), accuracy_average / len(dataloader) batch_true_links = batch[4].to(self.device)
batch_sentences_tokens = batch[5].to(self.device)
#endregion batch_sentences_mask = batch[6].to(self.device)
#region prediction
def predict_with_categories(self, sentence, categories):
r""" Predict the links from a sentence and its categories
Args :
sentence : list of words composing the sentence
categories : list of categories (tags) of each word
Return :
links : links prediction
"""
self.eval()
with torch.no_grad():
self.cpu()
self.device = torch.device("cpu")
sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors([sentence])
nb_sentence, len_sentence = sentences_tokens.shape
atoms = get_atoms_batch([categories])
atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms)
polarities = find_pos_neg_idexes([categories])
polarities = pad_sequence(
[torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
max_len=self.max_atoms_in_sentence, padding_value=0)
num_atoms_per_word = get_num_atoms_batch([categories], len_sentence)
pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type)
neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type)
output = self.Supertagger.forward(sentences_tokens, sentences_mask)
logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embedding'])
axiom_links_pred = torch.argmax(logits_predictions, dim=3)
return axiom_links_pred
def predict_without_categories(self, sentence):
r""" Predict the links from a sentence
Args :
sentence : list of words composing the sentence
Return :
categories : the supertags predicted
links : links prediction
"""
self.eval()
with torch.no_grad():
self.cpu()
self.device = torch.device("cpu")
sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentence)
nb_sentence, len_sentence = sentences_tokens.shape
hidden_state, categories = self.Supertagger.predict(sentence)
output = self.Supertagger.forward(sentences_tokens, sentences_mask) output = Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
atoms = get_atoms_batch(categories)
atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms)
polarities = find_pos_neg_idexes(categories) logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output['word_embedding']) # atom_vocab, batch_size, max atoms in one type, max atoms in one type
polarities = pad_sequence( axiom_links_pred = torch.argmax(logits_predictions, dim=3) # atom_vocab, batch_size, max atoms in one type
[torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
max_len=self.max_atoms_in_sentence, padding_value=0)
num_atoms_per_word = get_num_atoms_batch(categories, len_sentence) accuracy = measure_accuracy(batch_true_links, axiom_links_pred)
loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type) accuracy_average += accuracy
neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type) loss_average += float(loss)
logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embedding']) return loss_average / len(dataloader), accuracy_average / len(dataloader)
axiom_links_pred = torch.argmax(logits_predictions, dim=3)
return categories, axiom_links_pred
#endregion #endregion
def __checkpoint_save(self, path='/linker.pt'):
"""
@param path:
"""
self.cpu()
torch.save({
'atom_encoder': self.atom_encoder.state_dict(),
'position_encoder': self.position_encoder.state_dict(),
'transformer': self.transformer.state_dict(),
'linker_encoder': self.linker_encoder.state_dict(),
'pos_transformation': self.pos_transformation.state_dict(),
'neg_transformation': self.neg_transformation.state_dict(),
'cross_entropy_loss': self.cross_entropy_loss.state_dict(),
'optimizer': self.optimizer,
}, path)
self.to(self.device)
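With this commit the Linker no longer owns a SuperTagger: the tagger is built by NeuralProofNet and passed into train_epoch / eval_epoch explicitly. A minimal sketch of the new calling convention, assuming the import paths and the tagger weight path visible elsewhere in this diff, and with the dataloaders expected to come from NeuralProofNet's preprocessing:

from SuperTagger import SuperTagger
from Linker import Linker

# The caller (NeuralProofNet) now loads the tagger once...
supertagger = SuperTagger()
supertagger.load_weights("models/flaubert_super_98_V2_50e.pt")

# ...and the Linker is created without a tagger path.
linker = Linker()

# Training and evaluation receive the tagger per call (sketch, dataloaders omitted):
# avg_train_loss, avg_accuracy_train, training_time = linker.train_epoch(training_dataloader, supertagger)
# loss_test, accuracy_test = linker.eval_epoch(validation_dataloader, supertagger)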
import numpy as np
import torch import torch
from torch.nn import Module from torch.nn import Module
from torch.nn.functional import nll_loss from torch.nn.functional import nll_loss
from Linker.atom_map import atom_map, atom_map_redux from Linker.atom_map import atom_map, atom_map_redux
...@@ -12,8 +15,16 @@ class SinkhornLoss(Module): ...@@ -12,8 +15,16 @@ class SinkhornLoss(Module):
super(SinkhornLoss, self).__init__() super(SinkhornLoss, self).__init__()
def forward(self, predictions, truths): def forward(self, predictions, truths):
return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1) sum = 0
for link, perm in zip(predictions, truths.permute(1, 0, 2))) # for each category of atom (txt, np ...)
for link, perm in zip(predictions, truths.permute(1, 0, 2)):
# test whether there are any true links in this category
if 0 in perm.flatten():
# mean nll loss of the current category, computed over the whole batch
it = nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
# add it to the running total loss
sum += it
return sum
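A toy illustration of the reworked loss above; the tensor shapes are assumptions read off the comments in this diff (predictions: log-probabilities of shape atom_vocab x batch x max_atoms x max_atoms, truths: target indices of shape batch x atom_vocab x max_atoms, padded with -1):

import torch
from torch.nn.functional import nll_loss

atom_vocab, batch_size, max_atoms = 2, 3, 4
predictions = torch.log_softmax(torch.randn(atom_vocab, batch_size, max_atoms, max_atoms), dim=3)
truths = torch.randint(-1, max_atoms, (batch_size, atom_vocab, max_atoms))  # -1 marks padding

total = 0
for link, perm in zip(predictions, truths.permute(1, 0, 2)):  # one slice per atom category
    if 0 in perm.flatten():  # same guard as above: skip categories with no true links in this batch
        total += nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
print(float(total))

The running value is called total here only to avoid shadowing Python's built-in sum, which the committed code reuses as a variable name.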
def measure_accuracy(batch_true_links, axiom_links_pred): def measure_accuracy(batch_true_links, axiom_links_pred):
......
...@@ -8,13 +8,14 @@ from torch.utils.data import TensorDataset, random_split ...@@ -8,13 +8,14 @@ from torch.utils.data import TensorDataset, random_split
from tqdm import tqdm from tqdm import tqdm
from Configuration import Configuration from Configuration import Configuration
from NeuralProofNet.utils_proofnet import get_info_for_tagger
from SuperTagger import SuperTagger
from Linker import Linker from Linker import Linker
from Linker.eval import measure_accuracy, SinkhornLoss from Linker.eval import measure_accuracy, SinkhornLoss
from Linker.utils_linker import get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx, get_atoms_batch, \
from NeuralProofNet.utils_proofnet import get_info_for_tagger find_pos_neg_idexes, get_num_atoms_batch, generate_square_subsequent_mask
from utils import pad_sequence, format_time, output_create_dir from utils import pad_sequence, format_time, output_create_dir
class NeuralProofNet(Module): class NeuralProofNet(Module):
def __init__(self, supertagger_path_model, linker_path_model=None): def __init__(self, supertagger_path_model, linker_path_model=None):
...@@ -28,7 +29,13 @@ class NeuralProofNet(Module): ...@@ -28,7 +29,13 @@ class NeuralProofNet(Module):
self.max_atoms_in_one_type = int(datasetConfig['max_atoms_in_one_type']) self.max_atoms_in_one_type = int(datasetConfig['max_atoms_in_one_type'])
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
linker = Linker(supertagger_path_model) # SuperTagger for categories
supertagger = SuperTagger()
supertagger.load_weights(supertagger_path_model)
self.Supertagger = supertagger
self.Supertagger.model.to(self.device)
linker = Linker()
if linker_path_model is not None: if linker_path_model is not None:
linker.load_weights(linker_path_model) linker.load_weights(linker_path_model)
self.linker = linker self.linker = linker
...@@ -41,12 +48,6 @@ class NeuralProofNet(Module): ...@@ -41,12 +48,6 @@ class NeuralProofNet(Module):
self.to(self.device) self.to(self.device)
def __pretrain_linker__(self, df_axiom_links, pretrain_linker_epochs, batch_size, checkpoint=False, tensorboard=True):
print("\nLinker Pre-Training\n")
self.linker.train_linker(df_axiom_links, validation_rate=0.05, epochs=pretrain_linker_epochs,
batch_size=batch_size, checkpoint=checkpoint, tensorboard=tensorboard)
print("\nEND Linker Pre-Training\n")
def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1): def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
r""" r"""
Args: Args:
...@@ -54,26 +55,31 @@ class NeuralProofNet(Module): ...@@ -54,26 +55,31 @@ class NeuralProofNet(Module):
df_axiom_links pandas DataFrame df_axiom_links pandas DataFrame
validation_rate validation_rate
Returns: Returns:
the training dataloader and the validation dataloader. They contain the list of atoms, their polarities, the axiom links, the tokenized sentences and the sentence masks the training dataloader and the validation dataloader. They contain the list of atoms, their polarities, the axiom links, the tokenized sentences and the sentence masks
""" """
print("Start preprocess Data") print("Start preprocess Data")
sentences_batch = df_axiom_links["X"].str.strip().tolist() sentences_batch = df_axiom_links["X"].str.strip().tolist()
sentences_tokens, sentences_mask = self.linker.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch) sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
_, polarities, _ = get_GOAL(self.max_len_sentence, df_axiom_links) atoms_batch, polarities, num_atoms_per_word = get_GOAL(self.max_len_sentence, df_axiom_links)
atoms_polarity_batch = pad_sequence( atoms_polarity_batch = pad_sequence(
[torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))], [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
max_len=self.max_atoms_in_sentence, padding_value=0) max_len=self.max_atoms_in_sentence, padding_value=0)
atoms_batch_tokenized = self.linker.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch, pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.linker.max_atoms_in_one_type)
neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.linker.max_atoms_in_one_type)
truth_links_batch = get_axiom_links(self.linker.max_atoms_in_one_type, atoms_polarity_batch,
df_axiom_links["Y"]) df_axiom_links["Y"])
truth_links_batch = truth_links_batch.permute(1, 0, 2) truth_links_batch = truth_links_batch.permute(1, 0, 2)
# Construction tensor dataset # Construction tensor dataset
dataset = TensorDataset(truth_links_batch, sentences_tokens, sentences_mask) dataset = TensorDataset(num_atoms_per_word, atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch,
sentences_tokens, sentences_mask)
if validation_rate > 0.0: if validation_rate > 0.0:
train_size = int(0.9 * len(dataset)) train_size = int((1-validation_rate) * len(dataset))
val_size = len(dataset) - train_size val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False) validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
...@@ -85,13 +91,15 @@ class NeuralProofNet(Module): ...@@ -85,13 +91,15 @@ class NeuralProofNet(Module):
print("End preprocess Data") print("End preprocess Data")
return training_dataloader, validation_dataloader return training_dataloader, validation_dataloader
# region training
def forward(self, batch_sentences_tokens, batch_sentences_mask): def forward(self, batch_sentences_tokens, batch_sentences_mask):
# get sentence embedding from BERT which is already trained # get sentence embedding from BERT which is already trained
output = self.linker.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask) output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
last_hidden_state = output['logit'] last_hidden_state = output['logit']
pred_categories = torch.argmax(torch.softmax(last_hidden_state, dim=2), dim=2) pred_categories = torch.argmax(torch.softmax(last_hidden_state, dim=2), dim=2)
pred_categories = self.linker.Supertagger.tags_tokenizer.convert_ids_to_tags(pred_categories) pred_categories = self.Supertagger.tags_tokenizer.convert_ids_to_tags(pred_categories)
# get information from tagger predictions # get information from tagger predictions
atoms_batch, polarities, batch_num_atoms_per_word = get_info_for_tagger(self.max_len_sentence, pred_categories) atoms_batch, polarities, batch_num_atoms_per_word = get_info_for_tagger(self.max_len_sentence, pred_categories)
...@@ -112,6 +120,49 @@ class NeuralProofNet(Module): ...@@ -112,6 +120,49 @@ class NeuralProofNet(Module):
return torch.log_softmax(logits_links, dim=3) return torch.log_softmax(logits_links, dim=3)
def pretrain_linker(self, training_dataloader, validation_dataloader, pretrain_linker_epochs, checkpoint=None, writer=None):
r"""
Args:
training_dataloader : DataLoader over the preprocessed training batches (gold categories)
validation_dataloader : DataLoader over the validation batches, or None
pretrain_linker_epochs : int
checkpoint : boolean, save the linker weights after each epoch if True
writer : TensorBoard writer used for the accuracy and loss curves, or None
"""
for epoch_i in range(pretrain_linker_epochs):
print("")
print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, pretrain_linker_epochs))
print('Training...')
avg_train_loss, avg_accuracy_train, training_time = self.linker.train_epoch(training_dataloader, self.Supertagger)
print("")
print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
if validation_dataloader:
loss_test, accuracy_test = self.linker.eval_epoch(validation_dataloader, self.Supertagger)
print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
if checkpoint:
self.__checkpoint_save(path='Output/linker.pt')
if writer:
writer.add_scalars(f'Accuracy', {
'Train': avg_accuracy_train}, epoch_i)
writer.add_scalars(f'Loss', {
'Train': avg_train_loss}, epoch_i)
if validation_dataloader :
writer.add_scalars(f'Accuracy', {
'Validation': accuracy_test}, epoch_i)
writer.add_scalars(f'Loss', {
'Validation': loss_test}, epoch_i)
print('\n')
def train_neuralproofnet(self, df_axiom_links, validation_rate=0.1, epochs=20, pretrain_linker_epochs=0, def train_neuralproofnet(self, df_axiom_links, validation_rate=0.1, epochs=20, pretrain_linker_epochs=0,
batch_size=32, checkpoint=True, tensorboard=False): batch_size=32, checkpoint=True, tensorboard=False):
r""" r"""
...@@ -125,15 +176,20 @@ class NeuralProofNet(Module): ...@@ -125,15 +176,20 @@ class NeuralProofNet(Module):
Returns: Returns:
Final accuracy and final loss Final accuracy and final loss
""" """
# Pretrain the linker
self.__pretrain_linker__(df_axiom_links, pretrain_linker_epochs, batch_size)
# Start learning with output from tagger # Start learning with output from tagger
training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links, training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
validation_rate) validation_rate)
if checkpoint or tensorboard: if checkpoint or tensorboard:
checkpoint_dir, writer = output_create_dir() checkpoint_dir, writer = output_create_dir()
# Pretrain the linker with the correct (gold) categories
if pretrain_linker_epochs > 0:
print("\nLinker Pre-Training\n")
self.pretrain_linker(training_dataloader, validation_dataloader, \
pretrain_linker_epochs, checkpoint, writer)
print("\nEND Linker Pre-Training\n")
# Train Linker with predicted categories from supertagger
for epoch_i in range(epochs): for epoch_i in range(epochs):
print("") print("")
print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs)) print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
...@@ -153,14 +209,14 @@ class NeuralProofNet(Module): ...@@ -153,14 +209,14 @@ class NeuralProofNet(Module):
if tensorboard: if tensorboard:
writer.add_scalars(f'Accuracy', { writer.add_scalars(f'Accuracy', {
'Train': avg_accuracy_train}, epoch_i) 'Train': avg_accuracy_train}, pretrain_linker_epochs + epoch_i)
writer.add_scalars(f'Loss', { writer.add_scalars(f'Loss', {
'Train': avg_train_loss}, epoch_i) 'Train': avg_train_loss}, pretrain_linker_epochs + epoch_i)
if validation_rate > 0.0: if validation_rate > 0.0:
writer.add_scalars(f'Accuracy', { writer.add_scalars(f'Accuracy', {
'Validation': accuracy_test}, epoch_i) 'Validation': accuracy_test}, pretrain_linker_epochs + epoch_i)
writer.add_scalars(f'Loss', { writer.add_scalars(f'Loss', {
'Validation': loss_test}, epoch_i) 'Validation': loss_test}, pretrain_linker_epochs + epoch_i)
print('\n') print('\n')
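The scalar writes above are now offset by pretrain_linker_epochs so the pretraining curves and the main training curves share one continuous x-axis in TensorBoard. A tiny sketch of that bookkeeping (the epoch counts are example values, not the script defaults):

pretrain_linker_epochs, epochs = 5, 25
pretrain_steps = list(range(pretrain_linker_epochs))              # steps logged by pretrain_linker
main_steps = [pretrain_linker_epochs + e for e in range(epochs)]  # steps logged by this loop
assert main_steps[0] == pretrain_steps[-1] + 1  # the two phases no longer overwrite each other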
...@@ -184,9 +240,9 @@ class NeuralProofNet(Module): ...@@ -184,9 +240,9 @@ class NeuralProofNet(Module):
with tqdm(training_dataloader, unit="batch") as tepoch: with tqdm(training_dataloader, unit="batch") as tepoch:
for batch in tepoch: for batch in tepoch:
# Unpack this training batch from our dataloader # Unpack this training batch from our dataloader
batch_true_links = batch[0].to(self.device) batch_true_links = batch[4].to(self.device)
batch_sentences_tokens = batch[1].to(self.device) batch_sentences_tokens = batch[5].to(self.device)
batch_sentences_mask = batch[2].to(self.device) batch_sentences_mask = batch[6].to(self.device)
self.linker_optimizer.zero_grad() self.linker_optimizer.zero_grad()
...@@ -215,25 +271,10 @@ class NeuralProofNet(Module): ...@@ -215,25 +271,10 @@ class NeuralProofNet(Module):
avg_accuracy_train = accuracy_train / len(training_dataloader) avg_accuracy_train = accuracy_train / len(training_dataloader)
return avg_train_loss, avg_accuracy_train, training_time return avg_train_loss, avg_accuracy_train, training_time
#endregion
def eval_batch(self, batch): # region evaluation
batch_true_links = batch[0].to(self.device)
batch_sentences_tokens = batch[1].to(self.device)
batch_sentences_mask = batch[2].to(self.device)
logits_predictions_links = self(batch_sentences_tokens, batch_sentences_mask)
axiom_links_pred = torch.argmax(logits_predictions_links,
dim=3) # atom_vocab, batch_size, max atoms in one type
print('\n')
print("Les vrais liens de la catégorie n : ", batch_true_links[0][2][:100])
print("Les prédictions : ", axiom_links_pred[2][0][:100])
print('\n')
accuracy = measure_accuracy(batch_true_links, axiom_links_pred)
linker_loss = self.linker_loss(logits_predictions_links, batch_true_links)
return linker_loss, accuracy
def eval_epoch(self, dataloader): def eval_epoch(self, dataloader):
r"""Average the evaluation of all the batch. r"""Average the evaluation of all the batch.
...@@ -246,12 +287,24 @@ class NeuralProofNet(Module): ...@@ -246,12 +287,24 @@ class NeuralProofNet(Module):
loss_average = 0 loss_average = 0
with torch.no_grad(): with torch.no_grad():
for step, batch in enumerate(dataloader): for step, batch in enumerate(dataloader):
loss, accuracy = self.eval_batch(batch) batch_true_links = batch[4].to(self.device)
batch_sentences_tokens = batch[5].to(self.device)
batch_sentences_mask = batch[6].to(self.device)
logits_predictions_links = self(batch_sentences_tokens, batch_sentences_mask)
axiom_links_pred = torch.argmax(logits_predictions_links,
dim=3) # atom_vocab, batch_size, max atoms in one type
accuracy = measure_accuracy(batch_true_links, axiom_links_pred)
linker_loss = self.linker_loss(logits_predictions_links, batch_true_links)
accuracy_average += accuracy accuracy_average += accuracy
loss_average += float(loss) loss_average += float(linker_loss)
return loss_average / len(dataloader), accuracy_average / len(dataloader) return loss_average / len(dataloader), accuracy_average / len(dataloader)
#endregion
def __checkpoint_save(self, path='/linker.pt'): def __checkpoint_save(self, path='/linker.pt'):
""" """
@param path: @param path:
...@@ -268,4 +321,83 @@ class NeuralProofNet(Module): ...@@ -268,4 +321,83 @@ class NeuralProofNet(Module):
'cross_entropy_loss': self.linker_loss.state_dict(), 'cross_entropy_loss': self.linker_loss.state_dict(),
'optimizer': self.linker_optimizer, 'optimizer': self.linker_optimizer,
}, path) }, path)
self.to(self.device) self.to(self.device)
\ No newline at end of file
#region prediction
def predict_with_categories(self, sentence, categories):
r""" Predict the links from a sentence and its categories
Args :
sentence : list of words composing the sentence
categories : list of categories (tags) of each word
Return :
links : links prediction
"""
self.eval()
with torch.no_grad():
self.cpu()
self.device = torch.device("cpu")
sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentence)
nb_sentence, len_sentence = sentences_tokens.shape
atoms = get_atoms_batch(categories)
atoms_tokenized = self.linker.atoms_tokenizer.convert_batchs_to_ids(atoms)
polarities = find_pos_neg_idexes(categories)
polarities = pad_sequence(
[torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
max_len=self.max_atoms_in_sentence, padding_value=0)
num_atoms_per_word = get_num_atoms_batch(categories, len_sentence)
pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type)
neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type)
output = self.Supertagger.forward(sentences_tokens, sentences_mask)
logits_predictions = self.linker(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embedding'])
axiom_links_pred = torch.argmax(logits_predictions, dim=3)
return axiom_links_pred
def predict_without_categories(self, sentence):
r""" Predict the links from a sentence
Args :
sentence : list of words composing the sentence
Return :
categories : the supertags predicted
links : links prediction
"""
self.eval()
with torch.no_grad():
self.cpu()
self.device = torch.device("cpu")
sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentence)
nb_sentence, len_sentence = sentences_tokens.shape
hidden_state, categories = self.Supertagger.predict(sentence)
output = self.Supertagger.forward(sentences_tokens, sentences_mask)
atoms = get_atoms_batch(categories)
atoms_tokenized = self.linker.atoms_tokenizer.convert_batchs_to_ids(atoms)
polarities = find_pos_neg_idexes(categories)
polarities = pad_sequence(
[torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
max_len=self.max_atoms_in_sentence, padding_value=0)
num_atoms_per_word = get_num_atoms_batch(categories, len_sentence)
pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type)
neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type)
logits_predictions = self.linker(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embedding'])
axiom_links_pred = torch.argmax(logits_predictions, dim=3)
return categories, axiom_links_pred
#endregion
\ No newline at end of file
...@@ -4,22 +4,21 @@ from postprocessing import draw_sentence_output ...@@ -4,22 +4,21 @@ from postprocessing import draw_sentence_output
if __name__== '__main__': if __name__== '__main__':
# region data # region data
a_s = ["( 1 ) parmi les huit \" partants \" acquis ou potentiels , MM. Lacombe , Koehler et Laroze ne sont pas membres du PCF ."] a_s = ["( 1 ) parmi les huit \" partants \" acquis ou potentiels , MM. Lacombe , Koehler et Laroze ne sont pas membres du PCF ."]
tags_s = ['let', 'dr(0,s,s)', 'let', 'dr(0,dr(0,s,s),np)', 'dr(0,np,n)', 'dr(0,n,n)', 'let', 'n', 'let', 'dl(0,n,n)', tags_s = [['let', 'dr(0,s,s)', 'let', 'dr(0,dr(0,s,s),np)', 'dr(0,np,n)', 'dr(0,n,n)', 'let', 'n', 'let', 'dl(0,n,n)',
'dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))', 'dl(0,n,n)', 'let', 'dr(0,np,np)', 'np', 'dr(0,dl(0,np,np),np)', 'dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))', 'dl(0,n,n)', 'let', 'dr(0,np,np)', 'np', 'dr(0,dl(0,np,np),np)',
'np', 'dr(0,dl(0,np,np),np)', 'np', 'dr(0,dl(0,np,s),dl(0,np,s))', 'dr(0,dl(0,np,s),np)', 'dl(1,s,s)', 'np', 'np', 'dr(0,dl(0,np,np),np)', 'np', 'dr(0,dl(0,np,s),dl(0,np,s))', 'dr(0,dl(0,np,s),np)', 'dl(1,s,s)', 'np',
'dr(0,dl(0,np,np),n)', 'n', 'dl(0,s,txt)'] 'dr(0,dl(0,np,np),n)', 'n', 'dl(0,s,txt)']]
# endregion # endregion
# region model # region model
model_tagger = "models/flaubert_super_98_V2_50e.pt" model_tagger = "models/flaubert_super_98_V2_50e.pt"
neuralproofnet = NeuralProofNet(model_tagger) neuralproofnet = NeuralProofNet(model_tagger)
model = "Output/linker.pt" model = "Output/saved_linker.pt"
neuralproofnet.linker.load_weights(model) neuralproofnet.linker.load_weights(model)
# endregion # endregion
linker = neuralproofnet.linker #categories, links = neuralproofnet.predict_without_categories(a_s)
categories, links = linker.predict_without_categories(a_s) links = neuralproofnet.predict_with_categories(a_s, tags_s)
#links = linker.predict_with_categories(a_s, tags_s)
idx=0 idx=0
draw_sentence_output(a_s[idx].split(" "), categories[idx], links[:,idx,:].numpy()) draw_sentence_output(a_s[idx].split(" "), tags_s[idx], links[:,idx,:].numpy())
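The script above now calls predict_with_categories and leaves the tagger-driven path commented out. Under the same assumptions (both weight files exist locally; the NeuralProofNet import path is a guess, since it is not shown in this hunk), the alternative path added by this commit would look roughly like this:

from NeuralProofNet.NeuralProofNet import NeuralProofNet  # hypothetical import path
from postprocessing import draw_sentence_output

a_s = ["( 1 ) parmi les huit \" partants \" acquis ou potentiels , MM. Lacombe , Koehler et Laroze ne sont pas membres du PCF ."]

neuralproofnet = NeuralProofNet("models/flaubert_super_98_V2_50e.pt")
neuralproofnet.linker.load_weights("Output/saved_linker.pt")

# Here the supertagger predicts the categories itself, so they are returned with the links.
categories, links = neuralproofnet.predict_without_categories(a_s)
idx = 0
draw_sentence_output(a_s[idx].split(" "), categories[idx], links[:, idx, :].numpy())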
...@@ -6,8 +6,8 @@ torch.cuda.empty_cache() ...@@ -6,8 +6,8 @@ torch.cuda.empty_cache()
# region data # region data
file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv' file_path_axiom_links = 'Datasets/gold_dataset_links.csv'
df_axiom_links = read_links_csv(file_path_axiom_links) df_axiom_links = read_links_csv(file_path_axiom_links)[:32]
# endregion # endregion
...@@ -16,7 +16,7 @@ print("#" * 20) ...@@ -16,7 +16,7 @@ print("#" * 20)
print("#" * 20) print("#" * 20)
model_tagger = "models/flaubert_super_98_V2_50e.pt" model_tagger = "models/flaubert_super_98_V2_50e.pt"
neural_proof_net = NeuralProofNet(model_tagger) neural_proof_net = NeuralProofNet(model_tagger)
neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.1, epochs=25, pretrain_linker_epochs=20, batch_size=16, neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0, epochs=5, pretrain_linker_epochs=5, batch_size=16,
checkpoint=True, tensorboard=True) checkpoint=True, tensorboard=True)
print("#" * 20) print("#" * 20)
print("#" * 20) print("#" * 20)
......