From 74b5b56c6cf0e79231bfbe1bbd3556407c0df60a Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <cdepourt@montana.irit.fr> Date: Thu, 19 May 2022 11:27:58 +0200 Subject: [PATCH] change supertagger --- Linker/Linker.py | 16 ++++++++-------- train.py | 8 ++------ 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/Linker/Linker.py b/Linker/Linker.py index b4a5c80..a2c677b 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -1,12 +1,9 @@ import os +import sys from datetime import datetime -import torch -from torch.nn import Sequential, LayerNorm, Dropout -from torch.nn import Module import torch.nn.functional as F -import sys - +from torch.nn import Sequential, LayerNorm, Dropout from torch.optim import AdamW from torch.utils.data import TensorDataset, random_split from torch.utils.tensorboard import SummaryWriter @@ -16,11 +13,12 @@ from Configuration import Configuration from Linker.AtomEmbedding import AtomEmbedding from Linker.AtomTokenizer import AtomTokenizer from Linker.MHA import AttentionDecoderLayer -from Linker.atom_map import atom_map from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn +from Linker.atom_map import atom_map +from Linker.eval import mesure_accuracy, SinkhornLoss from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \ get_neg_encoding_for_s_idx -from Linker.eval import mesure_accuracy, SinkhornLoss +from Supertagger import * from utils import pad_sequence @@ -38,7 +36,7 @@ def output_create_dir(): class Linker(Module): - def __init__(self, supertagger): + def __init__(self, supertagger_path_model): super(Linker, self).__init__() self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder']) @@ -54,6 +52,8 @@ class Linker(Module): self.dropout = Dropout(0.1) self.device = "cpu" + supertagger = SuperTagger() + supertagger.load_weights(supertagger_path_model) self.Supertagger = supertagger self.atom_map = atom_map diff --git a/train.py b/train.py index 513e384..43e237b 100644 --- a/train.py +++ b/train.py @@ -13,12 +13,8 @@ file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv' df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences) sentences_batch = df_axiom_links["Sentences"].tolist() -supertagger = SuperTagger() -supertagger.load_weights("models/model_supertagger.pt") -sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch) print("Linker") -linker = Linker(supertagger) -linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) +linker = Linker("models/model_supertagger.pt") print("Linker Training") -linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True) +linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, tensorboard=True) -- GitLab