Skip to content
Snippets Groups Projects
Commit 35345fe6 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

update

parent 17baed42
No related branches found
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
......@@ -6,7 +6,7 @@ symbols_vocab_size=26
atom_vocab_size=20
max_len_sentence=109
max_atoms_in_sentence=1250
max_atoms_in_one_type=250
max_atoms_in_one_type=324
[MODEL_ENCODER]
dim_encoder = 768
......@@ -26,8 +26,7 @@ dropout=0.1
sinkhorn_iters=3
[MODEL_TRAINING]
device=cpu
batch_size=16
epoch=25
epoch=30
seed_val=42
learning_rate=2e-4
......@@ -22,9 +22,10 @@ from Linker.atom_map import atom_map
from Linker.eval import mesure_accuracy, SinkhornLoss
from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
get_neg_encoding_for_s_idx
from SuperTagger import *
from Supertagger import *
from utils import pad_sequence
def format_time(elapsed):
'''
Takes a time in seconds and returns a string hh:mm:ss
......@@ -35,6 +36,7 @@ def format_time(elapsed):
# Format as hh:mm:ss
return str(datetime.timedelta(seconds=elapsed_rounded))
def output_create_dir():
"""
Create le output dir for tensorboard and checkpoint
......@@ -99,16 +101,17 @@ class Linker(Module):
Returns:
the training dataloader and the validation dataloader. They contains the list of atoms, their polarities, the axiom links, the sentences tokenized, sentence mask
"""
sentences_batch = df_axiom_links["Sentences"].tolist()
print("Start preprocess Data")
sentences_batch = df_axiom_links["X"].tolist()
sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
atoms_batch = get_atoms_batch(df_axiom_links["Z"])
atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["sub_tree"])
atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["Z"])
truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
df_axiom_links["sub_tree"])
df_axiom_links["Y"])
truth_links_batch = truth_links_batch.permute(1, 0, 2)
# Construction tensor dataset
......@@ -125,6 +128,7 @@ class Linker(Module):
train_dataset = dataset
training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
print("End preprocess Data")
return training_dataloader, validation_dataloader
def make_decoder_mask(self, atoms_token):
......@@ -202,7 +206,6 @@ class Linker(Module):
print('Training...')
avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)
print("")
print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
......@@ -215,7 +218,6 @@ class Linker(Module):
self.__checkpoint_save(
path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
if tensorboard:
writer.add_scalars(f'Accuracy', {
'Train': avg_accuracy_train}, epoch_i)
......
......@@ -23,6 +23,7 @@ class FFN(Module):
return self.ffn(x)
regex_categories_axiom_links = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
......@@ -76,7 +77,7 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
elif True in res:
return [category]
else:
category_cut = regex.match(regex_categories, category).groups()
category_cut = regex.match(regex_categories_axiom_links, category).groups()
category_cut = [cat for cat in category_cut if cat is not None]
for cat in category_cut:
categories_to_atoms += category_to_atoms_axiom_links(cat, [])
......@@ -111,14 +112,16 @@ def category_to_atoms(category, categories_to_atoms):
Returns:
List of atoms inside the category in prefix order
"""
res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
res = [bool(re.match(r'' + atom_type, category)) for atom_type in atom_map.keys()]
if category.startswith("GOAL:"):
word, cat = category.split(':')
return category_to_atoms(cat, categories_to_atoms)
elif True in res:
category = re.match(r'([a-zA-Z|_]+)_\d+', category).group(1)
return [category]
elif category == "let":
return []
else:
print(category)
category_cut = regex.match(regex_categories, category).groups()
category_cut = [cat for cat in category_cut if cat is not None]
for cat in category_cut:
......@@ -155,21 +158,21 @@ def category_to_atoms_polarity(category, polarity):
Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
"""
category_to_polarity = []
res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
res = [bool(re.match(r'' + atom_type, category)) for atom_type in atom_map.keys()]
# mot final
if category.startswith("GOAL:"):
word, cat = category.split(':')
res = [bool(re.match(r'' + atom_type + "_\d+", cat)) for atom_type in atom_map.keys()]
res = [bool(re.match(r'' + atom_type, cat)) for atom_type in atom_map.keys()]
if True in res:
category_to_polarity.append(True)
else:
category_to_polarity += category_to_atoms_polarity(cat, True)
elif category == "let":
pass
# le mot a une category atomique
elif True in res or category.startswith("dia") or category.startswith("box"):
category_to_polarity.append(not polarity)
# sinon c'est une formule longue
else:
# dr = /
......
Supertagger @ eeb4774c
Subproject commit eeb4774c071e03f460f48798ab8d6820395825c9
#!/bin/sh
#SBATCH --job-name=Deepgrail_Linker
#SBATCH --job-name=Deepgrail_Linker_9000
#SBATCH --partition=RTX6000Node
#SBATCH --gres=gpu:1
#SBATCH --mem=32000
......
scp -r cdepourt@osirim-slurm.irit.fr:projets/deepgrail2/deepgrail_RNN_with_linker/TensorBoard/Tranning_19-05_09-49/logs /home/cdepourt/Bureau/deepgrail_RNN_with_linker/logs
rsync -av -e ssh --exclude="__pycache__" --exclude="venv" --exclude=".git" --exclude=".idea" -r /home/cdepourt/Bureau/deepgrail_RNN_with_linker cdepourt@osirim-slurm.irit.fr:projets/deepgrail2
......@@ -37,7 +37,7 @@ def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500):
chunksize = nrows
with tqdm(total=rows, desc='Rows read: ') as bar:
for chunk in pd.read_csv(csv_path, converters={'sub_tree': pd.eval, 'Z': pd.eval}, chunksize=chunksize, nrows=rows):
for chunk in pd.read_csv(csv_path, converters={'Y': pd.eval, 'Z': pd.eval}, chunksize=chunksize, nrows=rows):
chunk_list.append(chunk)
bar.update(len(chunk))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment