diff --git a/Linker/Linker.py b/Linker/Linker.py index c4069215ed55374c4b035e9ad69553a13a7510a8..3090e6fc2ab8f13567f04ad73cd568ddd192d8ea 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -1,12 +1,9 @@ import os +import sys from datetime import datetime -import torch -from torch.nn import Sequential, LayerNorm, Dropout -from torch.nn import Module import torch.nn.functional as F -import sys - +from torch.nn import Sequential, LayerNorm, Dropout from torch.optim import AdamW from torch.utils.data import TensorDataset, random_split from torch.utils.tensorboard import SummaryWriter @@ -16,11 +13,12 @@ from Configuration import Configuration from Linker.AtomEmbedding import AtomEmbedding from Linker.AtomTokenizer import AtomTokenizer from Linker.MHA import AttentionDecoderLayer -from Linker.atom_map import atom_map from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn +from Linker.atom_map import atom_map +from Linker.eval import mesure_accuracy, SinkhornLoss from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \ get_neg_encoding_for_s_idx -from Linker.eval import mesure_accuracy, SinkhornLoss +from Supertagger import * from utils import pad_sequence @@ -38,7 +36,7 @@ def output_create_dir(): class Linker(Module): - def __init__(self, supertagger): + def __init__(self, supertagger_path_model): super(Linker, self).__init__() self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder']) @@ -54,6 +52,8 @@ class Linker(Module): self.dropout = Dropout(0.1) self.device = "cpu" + supertagger = SuperTagger() + supertagger.load_weights(supertagger_path_model) self.Supertagger = supertagger self.atom_map = atom_map diff --git a/train.py b/train.py index 77e89b21a82e4ffd8483ed8295c03b387bdedd4b..8b8ff0d5eed0e9ed2601a048c020372df5530a52 100644 --- a/train.py +++ b/train.py @@ -14,12 +14,8 @@ file_path_axiom_links = 'Datasets/gold_dataset_links.csv' df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences) sentences_batch = df_axiom_links["Sentences"].tolist() -supertagger = SuperTagger() -supertagger.load_weights("models/flaubert_super_98%_V2_50e.pt") - print("Linker") -linker = Linker(supertagger) -linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) +linker = Linker("models/model_supertagger.pt") print("Linker Training") -linker.train_linker(df_axiom_links, batch_size=batch_size, checkpoint=False, tensorboard=True) +linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, tensorboard=True)