Commit e5bcb0e0 authored by Caroline de Pourtalès

Reduce the linker learning rate

parent 843cc70b
@@ -4,7 +4,7 @@ transformers = 4.16.2
 [DATASET_PARAMS]
 symbols_vocab_size = 26
 atom_vocab_size = 18
-max_len_sentence = 290
+max_len_sentence = 157
 max_atoms_in_sentence = 440
 max_atoms_in_one_type = 180
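These keys live in an INI-style file, so they can be read back with Python's stdlib configparser. A minimal sketch of loading the tuned bounds (the path Configuration/config.ini is an assumption, not taken from this diff):

# Minimal sketch: reading the dataset bounds above with configparser.
from configparser import ConfigParser

config = ConfigParser()
config.read('Configuration/config.ini')  # hypothetical path; the diff does not name the file

params = config['DATASET_PARAMS']
max_len_sentence = params.getint('max_len_sentence')            # 157 after this commit
max_atoms_in_sentence = params.getint('max_atoms_in_sentence')  # 440
print(max_len_sentence, max_atoms_in_sentence)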
@@ -88,8 +88,8 @@ class Linker(Module):
         # Learning
         self.cross_entropy_loss = SinkhornLoss()
-        self.optimizer = AdamW(self.parameters(), lr=0.001)
-        self.scheduler = StepLR(self.optimizer, step_size=2, gamma=0.5)
+        self.optimizer = AdamW(self.parameters(), lr=0.0001)
+        self.scheduler = StepLR(self.optimizer, step_size=5, gamma=0.5)
         self.to(self.device)

     def load_weights(self, model_file):
@@ -42,9 +42,8 @@ class NeuralProofNet(Module):
         # Learning
         self.linker_loss = SinkhornLoss()
-        self.linker_optimizer = AdamW(self.linker.parameters(),
-                                      lr=0.001)
-        self.linker_scheduler = StepLR(self.linker_optimizer, step_size=2, gamma=0.5)
+        self.linker_optimizer = AdamW(self.linker.parameters(), lr=0.0001)
+        self.linker_scheduler = StepLR(self.linker_optimizer, step_size=5, gamma=0.5)
         self.to(self.device)
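Taken together, the two hunks above drop the linker's starting learning rate from 0.001 to 0.0001 and stretch the decay interval from 2 to 5 epochs. A self-contained sketch of the resulting schedule (dummy parameter, not the repo's training loop; assumes scheduler.step() is called once per epoch):

# Sketch of the new schedule: AdamW starts at 1e-4 and StepLR halves it
# every 5 epochs (vs. 1e-3 halved every 2 epochs before this commit).
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter, stands in for the linker
optimizer = AdamW([param], lr=0.0001)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)

for epoch in range(25):
    # ... forward/backward pass would go here ...
    optimizer.step()
    scheduler.step()
    if epoch % 5 == 4:
        print(f"epoch {epoch + 1}: lr = {scheduler.get_last_lr()[0]:.2e}")
# -> epoch 5: lr = 5.00e-05, epoch 10: lr = 2.50e-05, ...

With 25 pretraining epochs (see the training script below), the rate now decays to 1/32 of its starting value instead of vanishing to roughly 1/4096 of it under the old step_size=2.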
-import torch
-from Linker import *
 from NeuralProofNet.NeuralProofNet import NeuralProofNet
+from find_config import configurate_linker
 from utils import read_links_csv
 import torch

 torch.cuda.empty_cache()

+dataset = 'Datasets/gold_dataset_links.csv'
+model_tagger = "models/flaubert_super_98_V2_50e.pt"

-# region data
-file_path_axiom_links = 'Datasets/gold_dataset_links.csv'
-df_axiom_links = read_links_csv(file_path_axiom_links)
-# endregion
+configurate_linker(dataset, model_tagger, nb_sentences=1000000000)
+df_axiom_links = read_links_csv(dataset)

 # region model
 print("#" * 20)
 print("#" * 20)
-model_tagger = "models/flaubert_super_98_V2_50e.pt"
 neural_proof_net = NeuralProofNet(model_tagger)
 neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.1, epochs=25, pretrain_linker_epochs=25, batch_size=16,
-                                      checkpoint=True, tensorboard=True)
-print("#" * 20)
-print("#" * 20)
-# endregion
\ No newline at end of file
+                                      checkpoint=True, tensorboard=True)
\ No newline at end of file
@@ -11,21 +11,17 @@ bert_model = "flaubert/flaubert_base_cased"
 configurate_supertagger(dataset, index_to_super_path, bert_model, nb_sentences=1000000000)

 # region data
 df = read_supertags_csv(dataset)
 texts = df['X'].tolist()
 tags = df['Z'].tolist()
 index_to_super = load_obj(index_to_super_path)
 # endregion

 # region model
 tagger = SuperTagger()
 tagger.create_new_model(len(index_to_super), bert_model, index_to_super)

 ## If you want to load a pretrained model instead
 # tagger.load_weights("models/model_check.pt")

 tagger.train(texts, tags, epochs=70, batch_size=16, validation_rate=0.1,
              tensorboard=True, checkpoint=True)
 # endregion
@@ -20,8 +20,12 @@ def read_links_csv(csv_path, nrows=float('inf'), chunksize=100):
     print("\n" + "#" * 20)
     print("Loading csv...")

     rows = sum(1 for _ in open(csv_path, 'r', encoding="utf8")) - 1  # minus the header
     chunk_list = []
     if rows > nrows:
         rows = nrows
     with tqdm(total=rows, desc='Rows read: ') as bar:
         for chunk in pd.read_csv(csv_path, header=0, converters={'Y': pd.eval, 'Z': pd.eval},
                                  chunksize=chunksize, nrows=nrows):
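The hunk above is truncated, so the loop body and return are not shown. A self-contained sketch of how such a chunked loader typically completes; the chunk collection, progress-bar updates, and pd.concat tail are assumptions, not the repo's code:

# Sketch of a complete chunked CSV loader in the style of read_links_csv.
import pandas as pd
from tqdm import tqdm

def read_links_csv_sketch(csv_path, nrows=None, chunksize=100):
    # Count data rows once so the progress bar has a real total,
    # mirroring the rows/nrows capping logic in the hunk above.
    with open(csv_path, 'r', encoding='utf8') as f:
        rows = sum(1 for _ in f) - 1  # minus the header
    if nrows is not None and rows > nrows:
        rows = nrows
    chunk_list = []
    with tqdm(total=rows, desc='Rows read: ') as bar:
        # 'Y' and 'Z' store list literals; pd.eval parses them back into lists,
        # as in the repo's converters. nrows defaults to None here (rather than
        # float('inf')) so pd.read_csv receives a valid value.
        for chunk in pd.read_csv(csv_path, header=0,
                                 converters={'Y': pd.eval, 'Z': pd.eval},
                                 chunksize=chunksize, nrows=nrows):
            chunk_list.append(chunk)
            bar.update(len(chunk))
    # Assumed final step: stitch the chunks into one DataFrame.
    return pd.concat(chunk_list, ignore_index=True)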