Commit 76abb835 authored by Caroline de Pourtalès

training

parent 897a4516
Branch: main
No related tags found
2 merge requests: !3 Prepare paper, !2 Prepare paper
@@ -11,9 +11,9 @@ import pandas as pd
 def normalize_word(orig_word):
     word = orig_word.lower()
-    if (word is "["):
+    if (word == "["):
         word = "("
-    if (word is "]"):
+    if (word == "]"):
         word = ")"
     return word
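The hunk above fixes a real bug: `is` compares object identity while `==` compares values, so `word is "["` can be False even when `word` equals `"["` (CPython 3.8+ also emits a SyntaxWarning for `is` against a literal). A minimal sketch of the difference:

# Value equality vs. object identity in Python
a = "abab"
b = "".join(["ab", "ab"])  # equal value, but a freshly built object
print(a == b)   # True  -- same characters
print(a is b)   # False in CPython -- different objects in memory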
File moved
import configparser
import re

import torch

from Linker.atom_map import atom_map_redux
from Linker.utils_linker import get_GOAL, get_atoms_links_batch, get_atoms_batch
from SuperTagger.SuperTagger.SuperTagger import SuperTagger
from utils import read_csv_pgbar, pad_sequence


def configurate(dataset, model_tagger, nb_sentences=1000000000):
    """Infer dataset-dependent size parameters and save them to config.ini."""
    print("#" * 20)
    print("#" * 20)
    print("Configuration with dataset\n")

    config = configparser.ConfigParser()
    config.read('Configuration/config.ini')

    file_path_axiom_links = dataset
    df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)

    supertagger = SuperTagger()
    supertagger.load_weights(model_tagger)

    # Longest tokenized sentence in the dataset.
    sentences_batch = df_axiom_links["X"].str.strip().tolist()
    sentences_tokens, sentences_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
    max_len_sentence = 0
    for sentence in sentences_tokens:
        if len(sentence) > max_len_sentence:
            max_len_sentence = len(sentence)
    print("Configure parameter max len sentence to", max_len_sentence)
    config.set('DATASET_PARAMS', 'max_len_sentence', str(max_len_sentence))

    # Largest number of atoms occurring in one sentence.
    atoms_batch, polarities, num_batch = get_GOAL(max_len_sentence, df_axiom_links)
    max_atoms_in_sentence = 0
    for sentence in atoms_batch:
        if len(sentence) > max_atoms_in_sentence:
            max_atoms_in_sentence = len(sentence)
    print("Configure parameter max atoms in sentence to", max_atoms_in_sentence)
    config.set('DATASET_PARAMS', 'max_atoms_in_sentence', str(max_atoms_in_sentence))

    # For each atom type, collect the indices of positive-polarity atoms
    # in every sentence (atom names may carry a "_suffix").
    atoms_polarity_batch = pad_sequence(
        [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
        max_len=max_atoms_in_sentence, padding_value=0)
    pos_idx = [[torch.as_tensor([i for i, x in enumerate(sentence)
                                 if bool(re.match(atom_type + r"(_\w+)?\Z", atoms_batch[s_idx][i]))
                                 and atoms_polarity_batch[s_idx][i]])
                for s_idx, sentence in enumerate(atoms_batch)]
               for atom_type in list(atom_map_redux.keys())]

    # Largest number of atoms of a single type in one sentence.
    max_atoms_in_one_type = 0
    for atoms_type_batch in pos_idx:
        for sentence in atoms_type_batch:
            length = sentence.size(0)
            if length > max_atoms_in_one_type:
                max_atoms_in_one_type = length
    print("Configure parameter max atoms of one type in one sentence to", max_atoms_in_one_type)
    # Stored with headroom: the observed maximum doubled, plus two.
    config.set('DATASET_PARAMS', 'max_atoms_in_one_type', str(max_atoms_in_one_type * 2 + 2))

    # Save the updated configuration.
    with open('Configuration/config.ini', 'w') as configfile:
        config.write(configfile)

    print("#" * 20)
    print("#" * 20)
slurm.sh 0 → 100644
#!/bin/sh
#SBATCH --job-name=Deepgrail_Linker
#SBATCH --partition=RTX6000Node
#SBATCH --gres=gpu:1
#SBATCH --mem=32000
#SBATCH --gres-flags=enforce-binding
#SBATCH --error="error_rtx1.err"
#SBATCH --output="out_rtx1.out"
module purge
module load singularity/3.0.3
srun singularity exec /logiciels/containerCollections/CUDA11/pytorch-NGC-21-03-py3.sif python "train_neuralproofnet.py"
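Assuming a standard SLURM setup, the script is submitted with `sbatch slurm.sh`; `squeue --name=Deepgrail_Linker` then shows the queued job, and stdout/stderr land in out_rtx1.out and error_rtx1.err as configured above.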
@@ -16,7 +16,7 @@ print("#" * 20)
 print("#" * 20)
 model_tagger = "models/flaubert_super_98_V2_50e.pt"
 neural_proof_net = NeuralProofNet(model_tagger)
-neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.1, epochs=5, batch_size=16,
+neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.1, epochs=25, pretrain_linker_epochs=20, batch_size=16,
                                       checkpoint=True, tensorboard=True)
 print("#" * 20)
 print("#" * 20)
from SuperTagger.SuperTagger.SuperTagger import SuperTagger
from utils import read_supertags_csv, load_obj
import torch
torch.cuda.empty_cache()
# region data
file_path = 'SuperTagger/Datasets/m2_dataset_V2.csv'
@@ -16,7 +18,7 @@ tagger = SuperTagger()
 tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super)
 ## If you want to upload a pretrained model
 # tagger.load_weights("models/model_check.pt")
-tagger.train(texts, tags, epochs=2, batch_size=16, validation_rate=0.1,
+tagger.train(texts, tags, epochs=40, batch_size=16, validation_rate=0.1,
              tensorboard=True, checkpoint=True)
 # endregion
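Both training scripts pass tensorboard=True; assuming the library writes event files to a local log directory (the exact path is set inside SuperTagger and is an assumption here), progress can be followed with `tensorboard --logdir <log_dir>` while the 40-epoch run proceeds.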