From 501425a2c3ee8aa8ce612a24ad05ebb1a37294d4 Mon Sep 17 00:00:00 2001
From: PNRIA - Julien <julien.rabault@irit.fr>
Date: Thu, 19 May 2022 15:50:25 +0200
Subject: [PATCH] Add logs

---
 .gitignore       | 2 +-
 Linker/Linker.py | 3 +--
 train.py         | 6 +++---
 3 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/.gitignore b/.gitignore
index ac5aa4b..72fcf74 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,4 @@
-deepgrail_Tagger
+SuperTagger
 Utils/silver
 Utils/gold
 .idea
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 816b589..ff37ef8 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -6,7 +6,6 @@ import time
 
 import torch
 import torch.nn.functional as F
-from torch import Module
 from torch.nn import Sequential, LayerNorm, Dropout
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
@@ -23,7 +22,7 @@ from Linker.atom_map import atom_map
 from Linker.eval import mesure_accuracy, SinkhornLoss
 from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
     get_neg_encoding_for_s_idx
-from Supertagger import *
+from SuperTagger import *
 from utils import pad_sequence
 
 def format_time(elapsed):
diff --git a/train.py b/train.py
index 4c6645e..50354a0 100644
--- a/train.py
+++ b/train.py
@@ -5,13 +5,13 @@ from utils import read_csv_pgbar
 
 torch.cuda.empty_cache()
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 40
+nb_sentences = batch_size * 2
 epochs = int(Configuration.modelTrainingConfig['epoch'])
 
 file_path_axiom_links = 'Datasets/gold_dataset_links.csv'
 df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 
 print("Linker")
-linker = Linker("models/model_supertagger.pt")
-print("Linker Training")
+linker = Linker("models/flaubert_super_98%_V2_50e.pt")
+print("\nLinker Training\n\n")
 linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=False, tensorboard=True)
-- 
GitLab