From 74b5b56c6cf0e79231bfbe1bbd3556407c0df60a Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Thu, 19 May 2022 11:27:58 +0200
Subject: [PATCH] Load the supertagger from a checkpoint path in Linker
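
The Linker no longer receives a pre-built SuperTagger instance: its
constructor now takes the path to a saved supertagger model and loads
the weights itself. train.py drops the manual supertagger setup,
sentence tokenization and device placement, and enables tensorboard
logging during training.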

---
 Linker/Linker.py | 16 ++++++++--------
 train.py         |  8 ++------
 2 files changed, 10 insertions(+), 14 deletions(-)
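
Notes: a minimal usage sketch of the new constructor. Import paths and
the epoch/batch values are assumptions for illustration, not taken from
this patch; the CSV path and read_csv_pgbar signature come from train.py.

    # Sketch: Linker now builds and loads the supertagger itself from a
    # checkpoint path, so callers only pass the path.
    from Linker.Linker import Linker
    from utils import read_csv_pgbar  # import path assumed

    # Load 100 sentences' worth of axiom links (count is a placeholder).
    df_axiom_links = read_csv_pgbar('Datasets/goldANDsilver_dataset_links.csv', 100)

    linker = Linker("models/model_supertagger.pt")
    linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=25,
                        batch_size=16, checkpoint=True, tensorboard=True)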

diff --git a/Linker/Linker.py b/Linker/Linker.py
index b4a5c80..a2c677b 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -1,12 +1,9 @@
 import os
+import sys
 from datetime import datetime
 
-import torch
-from torch.nn import Sequential, LayerNorm, Dropout
-from torch.nn import Module
 import torch.nn.functional as F
-import sys
-
+from torch.nn import Module, Sequential, LayerNorm, Dropout
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
@@ -16,11 +13,12 @@ from Configuration import Configuration
 from Linker.AtomEmbedding import AtomEmbedding
 from Linker.AtomTokenizer import AtomTokenizer
 from Linker.MHA import AttentionDecoderLayer
-from Linker.atom_map import atom_map
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+from Linker.atom_map import atom_map
+from Linker.eval import mesure_accuracy, SinkhornLoss
 from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
     get_neg_encoding_for_s_idx
-from Linker.eval import mesure_accuracy, SinkhornLoss
+from Supertagger import *
 from utils import pad_sequence
 
 
@@ -38,7 +36,7 @@ def output_create_dir():
 
 
 class Linker(Module):
-    def __init__(self, supertagger):
+    def __init__(self, supertagger_path_model):
         super(Linker, self).__init__()
 
         self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
@@ -54,6 +52,8 @@ class Linker(Module):
         self.dropout = Dropout(0.1)
         self.device = "cpu"
 
+        supertagger = SuperTagger()
+        supertagger.load_weights(supertagger_path_model)
         self.Supertagger = supertagger
 
         self.atom_map = atom_map
diff --git a/train.py b/train.py
index 513e384..43e237b 100644
--- a/train.py
+++ b/train.py
@@ -13,12 +13,8 @@ file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
 df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 
 sentences_batch = df_axiom_links["Sentences"].tolist()
-supertagger = SuperTagger()
-supertagger.load_weights("models/model_supertagger.pt")
-sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
 print("Linker")
-linker = Linker(supertagger)
-linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+linker = Linker("models/model_supertagger.pt")
 print("Linker Training")
-linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True)
+linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, tensorboard=True)
-- 
GitLab