Skip to content
Snippets Groups Projects
Commit 74b5b56c authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

change supertagger

parent 1b94de2d
Branches
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
import os
import sys
from datetime import datetime
import torch
from torch.nn import Sequential, LayerNorm, Dropout
from torch.nn import Module
import torch.nn.functional as F
import sys
from torch.nn import Sequential, LayerNorm, Dropout
from torch.optim import AdamW
from torch.utils.data import TensorDataset, random_split
from torch.utils.tensorboard import SummaryWriter
......@@ -16,11 +13,12 @@ from Configuration import Configuration
from Linker.AtomEmbedding import AtomEmbedding
from Linker.AtomTokenizer import AtomTokenizer
from Linker.MHA import AttentionDecoderLayer
from Linker.atom_map import atom_map
from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from Linker.atom_map import atom_map
from Linker.eval import mesure_accuracy, SinkhornLoss
from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
get_neg_encoding_for_s_idx
from Linker.eval import mesure_accuracy, SinkhornLoss
from Supertagger import *
from utils import pad_sequence
......@@ -38,7 +36,7 @@ def output_create_dir():
class Linker(Module):
def __init__(self, supertagger):
def __init__(self, supertagger_path_model):
super(Linker, self).__init__()
self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
......@@ -54,6 +52,8 @@ class Linker(Module):
self.dropout = Dropout(0.1)
self.device = "cpu"
supertagger = SuperTagger()
supertagger.load_weights(supertagger_path_model)
self.Supertagger = supertagger
self.atom_map = atom_map
......
......@@ -13,12 +13,8 @@ file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
sentences_batch = df_axiom_links["Sentences"].tolist()
supertagger = SuperTagger()
supertagger.load_weights("models/model_supertagger.pt")
sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
print("Linker")
linker = Linker(supertagger)
linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
linker = Linker("models/model_supertagger.pt")
print("Linker Training")
linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True)
linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, tensorboard=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment