Skip to content
Snippets Groups Projects
Commit 501425a2 authored by Julien Rabault's avatar Julien Rabault
Browse files

Add logs

parent ff8dc3af
Branches
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
deepgrail_Tagger SuperTagger
Utils/silver Utils/silver
Utils/gold Utils/gold
.idea .idea
......
...@@ -6,7 +6,6 @@ import time ...@@ -6,7 +6,6 @@ import time
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import Module
from torch.nn import Sequential, LayerNorm, Dropout from torch.nn import Sequential, LayerNorm, Dropout
from torch.optim import AdamW from torch.optim import AdamW
from torch.utils.data import TensorDataset, random_split from torch.utils.data import TensorDataset, random_split
...@@ -23,7 +22,7 @@ from Linker.atom_map import atom_map ...@@ -23,7 +22,7 @@ from Linker.atom_map import atom_map
from Linker.eval import mesure_accuracy, SinkhornLoss from Linker.eval import mesure_accuracy, SinkhornLoss
from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \ from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
get_neg_encoding_for_s_idx get_neg_encoding_for_s_idx
from Supertagger import * from SuperTagger import *
from utils import pad_sequence from utils import pad_sequence
def format_time(elapsed): def format_time(elapsed):
......
...@@ -5,13 +5,13 @@ from utils import read_csv_pgbar ...@@ -5,13 +5,13 @@ from utils import read_csv_pgbar
torch.cuda.empty_cache() torch.cuda.empty_cache()
batch_size = int(Configuration.modelTrainingConfig['batch_size']) batch_size = int(Configuration.modelTrainingConfig['batch_size'])
nb_sentences = batch_size * 40 nb_sentences = batch_size * 2
epochs = int(Configuration.modelTrainingConfig['epoch']) epochs = int(Configuration.modelTrainingConfig['epoch'])
file_path_axiom_links = 'Datasets/gold_dataset_links.csv' file_path_axiom_links = 'Datasets/gold_dataset_links.csv'
df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences) df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
print("Linker") print("Linker")
linker = Linker("models/model_supertagger.pt") linker = Linker("models/flaubert_super_98%_V2_50e.pt")
print("Linker Training") print("\nLinker Training\n\n")
linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=False, tensorboard=True) linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=False, tensorboard=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment