Skip to content
Snippets Groups Projects
Commit 74b5b56c authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

change supertagger

parent 1b94de2d
No related branches found
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
import os
import sys
from datetime import datetime
import torch
from torch.nn import Sequential, LayerNorm, Dropout
from torch.nn import Module
import torch.nn.functional as F
import sys
from torch.nn import Sequential, LayerNorm, Dropout
from torch.optim import AdamW
from torch.utils.data import TensorDataset, random_split
from torch.utils.tensorboard import SummaryWriter
......@@ -16,11 +13,12 @@ from Configuration import Configuration
from Linker.AtomEmbedding import AtomEmbedding
from Linker.AtomTokenizer import AtomTokenizer
from Linker.MHA import AttentionDecoderLayer
from Linker.atom_map import atom_map
from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from Linker.atom_map import atom_map
from Linker.eval import mesure_accuracy, SinkhornLoss
from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
get_neg_encoding_for_s_idx
from Linker.eval import mesure_accuracy, SinkhornLoss
from Supertagger import *
from utils import pad_sequence
......@@ -38,7 +36,7 @@ def output_create_dir():
class Linker(Module):
def __init__(self, supertagger):
def __init__(self, supertagger_path_model):
super(Linker, self).__init__()
self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
......@@ -54,6 +52,8 @@ class Linker(Module):
self.dropout = Dropout(0.1)
self.device = "cpu"
supertagger = SuperTagger()
supertagger.load_weights(supertagger_path_model)
self.Supertagger = supertagger
self.atom_map = atom_map
......
......@@ -13,12 +13,8 @@ file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
sentences_batch = df_axiom_links["Sentences"].tolist()
supertagger = SuperTagger()
supertagger.load_weights("models/model_supertagger.pt")
sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
print("Linker")
linker = Linker(supertagger)
linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
linker = Linker("models/model_supertagger.pt")
print("Linker Training")
linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True)
linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, tensorboard=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment