Commit 304c6295 authored by Julien Rabault

Merge branch 'version-linker' of https://gitlab.irit.fr/pnria/global-helper/deepgrail-linker into version-linker

# Conflicts:
#	train.py
Parents: d7a164e5, 74b5b56c
Related merge requests: !6 Linker with transformer, !5 Linker with transformer
 import os
+import sys
 from datetime import datetime
+import torch
+from torch.nn import Sequential, LayerNorm, Dropout
+from torch.nn import Module
 import torch.nn.functional as F
-import sys
-from torch.nn import Sequential, LayerNorm, Dropout
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
@@ -16,11 +13,12 @@ from Configuration import Configuration
 from Linker.AtomEmbedding import AtomEmbedding
 from Linker.AtomTokenizer import AtomTokenizer
 from Linker.MHA import AttentionDecoderLayer
+from Linker.atom_map import atom_map
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from Linker.atom_map import atom_map
+from Linker.eval import mesure_accuracy, SinkhornLoss
 from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
     get_neg_encoding_for_s_idx
-from Linker.eval import mesure_accuracy, SinkhornLoss
+from Supertagger import *
 from utils import pad_sequence
@@ -38,7 +36,7 @@ def output_create_dir():
 class Linker(Module):
-    def __init__(self, supertagger):
+    def __init__(self, supertagger_path_model):
         super(Linker, self).__init__()
         self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
@@ -54,6 +52,8 @@ class Linker(Module):
         self.dropout = Dropout(0.1)
         self.device = "cpu"
+        supertagger = SuperTagger()
+        supertagger.load_weights(supertagger_path_model)
         self.Supertagger = supertagger
         self.atom_map = atom_map
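
The hunks above change the Linker API: instead of receiving a ready-made supertagger instance, the constructor now takes a checkpoint path and loads the SuperTagger itself (via the `from Supertagger import *` added to the imports). A minimal before/after sketch of a call site, using only the checkpoint paths that appear in this diff:

# Before this merge: the caller built and loaded the supertagger itself.
supertagger = SuperTagger()
supertagger.load_weights("models/flaubert_super_98%_V2_50e.pt")
linker = Linker(supertagger)

# After this merge: the caller passes a path; Linker.__init__ runs
# SuperTagger() and load_weights() internally, as the hunk above shows.
linker = Linker("models/model_supertagger.pt")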
train.py
@@ -14,12 +14,8 @@ file_path_axiom_links = 'Datasets/gold_dataset_links.csv'
 df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 sentences_batch = df_axiom_links["Sentences"].tolist()
 
-supertagger = SuperTagger()
-supertagger.load_weights("models/flaubert_super_98%_V2_50e.pt")
-
 print("Linker")
-linker = Linker(supertagger)
-linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+linker = Linker("models/model_supertagger.pt")
 
 print("Linker Training")
-linker.train_linker(df_axiom_links, batch_size=batch_size, checkpoint=False, tensorboard=True)
+linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, tensorboard=True)
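
The train_linker call also gains validation_rate and epochs keyword arguments and enables checkpointing at this call site. A sketch of the signature this call implies; the keyword names come from the call itself, while the default values and body are assumptions, not confirmed by the diff:

def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
                 batch_size=32, checkpoint=True, tensorboard=False):
    # Train the linker on the gold axiom-link dataframe, holding out a
    # validation_rate fraction for validation (inferred from the call site).
    ...

Note that the explicit linker.to(torch.device(...)) line is dropped from train.py, consistent with the constructor above initialising self.device = "cpu"; device placement presumably moves inside the class.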