Commit 74b5b56c authored by Caroline DE POURTALES

change supertagger

parent 1b94de2d
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
 import os
+import sys
 from datetime import datetime
+import torch
+from torch.nn import Sequential, LayerNorm, Dropout
+from torch.nn import Module
 import torch.nn.functional as F
-import sys
-from torch.nn import Sequential, LayerNorm, Dropout
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
@@ -16,11 +13,12 @@ from Configuration import Configuration
 from Linker.AtomEmbedding import AtomEmbedding
 from Linker.AtomTokenizer import AtomTokenizer
 from Linker.MHA import AttentionDecoderLayer
-from Linker.atom_map import atom_map
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+from Linker.atom_map import atom_map
+from Linker.eval import mesure_accuracy, SinkhornLoss
 from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
     get_neg_encoding_for_s_idx
-from Linker.eval import mesure_accuracy, SinkhornLoss
+from Supertagger import *
 from utils import pad_sequence
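(Note: the new wildcard import from Supertagger is what brings the SuperTagger class into scope for the constructor change in the hunks below.)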
@@ -38,7 +36,7 @@ def output_create_dir():
 class Linker(Module):
-    def __init__(self, supertagger):
+    def __init__(self, supertagger_path_model):
         super(Linker, self).__init__()
         self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
@@ -54,6 +52,8 @@ class Linker(Module):
         self.dropout = Dropout(0.1)
         self.device = "cpu"
+        supertagger = SuperTagger()
+        supertagger.load_weights(supertagger_path_model)
         self.Supertagger = supertagger
         self.atom_map = atom_map
...
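Taken together, the hunks above change the constructor's contract: Linker.__init__ now takes the path to a supertagger checkpoint instead of a ready-made tagger instance, and builds and loads the SuperTagger itself. A minimal sketch of the resulting pattern, with the rest of the class elided, assuming only the SuperTagger API visible in this diff (a no-argument constructor plus load_weights) and an explicit-name import in place of the wildcard:

    from torch.nn import Module

    from Supertagger import SuperTagger


    class Linker(Module):
        def __init__(self, supertagger_path_model):
            super(Linker, self).__init__()
            # Build the tagger here and restore its pretrained weights from
            # the given checkpoint, rather than expecting the caller to do it.
            supertagger = SuperTagger()
            supertagger.load_weights(supertagger_path_model)
            self.Supertagger = supertagger

One consequence shows up in the training script below: callers no longer need to know how to construct or load the tagger, only where its checkpoint lives.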
@@ -13,12 +13,8 @@ file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
 df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 sentences_batch = df_axiom_links["Sentences"].tolist()
-supertagger = SuperTagger()
-supertagger.load_weights("models/model_supertagger.pt")
-sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 print("Linker")
-linker = Linker(supertagger)
-linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+linker = Linker("models/model_supertagger.pt")
 print("Linker Training")
-linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True)
+linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, tensorboard=True)
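After the change, the training entry point reduces to two calls. The script also no longer pre-tokenizes the sentences or moves the model to a CUDA device by hand; those lines were dropped along with the manual SuperTagger setup, and the new tensorboard=True flag, per its name, turns on TensorBoard logging during training. A usage sketch assembled from the script above (df_axiom_links, epochs and batch_size are defined earlier in the script):

    linker = Linker("models/model_supertagger.pt")
    linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs,
                        batch_size=batch_size, checkpoint=True, tensorboard=True)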