Commit 5ebfe402 authored by Caroline DE POURTALES

best score 81%

parent d160eeee
@@ -13,8 +13,8 @@ dim_encoder = 768

 [MODEL_LINKER]
 nhead=8
-dim_emb_atom = 256
-dim_feedforward_transformer = 512
+dim_emb_atom = 512
+dim_feedforward_transformer = 768
 num_layers=3
 dim_cat_inter=768
 dim_cat_out=512
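For context, a minimal sketch (an assumption about the architecture, not code from the repo) of how these [MODEL_LINKER] values would typically parameterize a torch TransformerEncoder over the atom embeddings:

import torch.nn as nn

# hypothetical wiring of the config values above
encoder_layer = nn.TransformerEncoderLayer(d_model=512,          # dim_emb_atom
                                           nhead=8,              # nhead
                                           dim_feedforward=768)  # dim_feedforward_transformer
atom_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)  # num_layers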
@@ -128,7 +128,7 @@ class Linker(Module):
         self.cross_entropy_loss = SinkhornLoss()
         self.optimizer = AdamW(self.parameters(),
                                lr=learning_rate)
-        self.scheduler = StepLR(self.optimizer, step_size=3, gamma=0.5)
+        self.scheduler = StepLR(self.optimizer, step_size=2, gamma=0.5)
         self.to(self.device)
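A minimal standalone sketch (not the repo's training loop) of what the new StepLR settings do: the learning rate is now halved every 2 epochs instead of every 3.

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = AdamW(params, lr=1e-3)
scheduler = StepLR(optimizer, step_size=2, gamma=0.5)

for epoch in range(6):
    print(epoch, optimizer.param_groups[0]["lr"])  # 1e-3, 1e-3, 5e-4, 5e-4, 2.5e-4, 2.5e-4
    # ... training step would go here ...
    optimizer.step()
    scheduler.step()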
@@ -257,8 +257,6 @@ class Linker(Module):
             if tensorboard:
                 writer.add_scalars(f'Accuracy', {
                     'Train': avg_accuracy_train}, epoch_i)
-                writer.add_scalars(f'Learning rate', {
-                    'learning_rate': self.scheduler.get_last_lr()}, epoch_i)
                 writer.add_scalars(f'Loss', {
                     'Train': avg_train_loss}, epoch_i)
             if validation_rate > 0.0:
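The hunk above drops the learning-rate panel. A possible reason (an assumption, not stated in the commit) is that get_last_lr() returns a list, one value per parameter group, which does not fit the scalar values expected by add_scalars. If the curve is wanted back, a one-line alternative logging the first group's value as a plain scalar would be:

writer.add_scalar('Learning rate', self.scheduler.get_last_lr()[0], epoch_i)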
@@ -352,8 +352,7 @@ print(" test for get GOAL ", get_GOAL(10, 30, df_axiom_links))
 def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(
-                                                  re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
                                               atoms_polarity_batch[s_idx][i]])
                              for s_idx, sentence in enumerate(atoms_batch)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
@@ -364,8 +363,7 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_at
 def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(
-                                                  re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
                                               not atoms_polarity_batch[s_idx][i]])
                              for s_idx, sentence in enumerate(atoms_batch)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
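The regex in these two helpers matches either a bare atom type or the type followed by a single _suffix; a standalone sanity check with the standard re module:

import re

atom_type = "np"
pattern = r"" + atom_type + r"(_{1}\w+)?\Z"
print(bool(re.match(pattern, "np")))           # True: bare atom
print(bool(re.match(pattern, "np_1")))         # True: atom with an index suffix
print(bool(re.match(pattern, "n")))            # False: different atom type
print(bool(re.match(r"n(_{1}\w+)?\Z", "np")))  # False: \Z stops 'n' from matching 'np'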
@@ -383,3 +381,22 @@ print(" test for cut into pos neg on ['dr(0,s,np)', 's']",
                                           False, False]]), 10, 50))
 # endregion
+# region style Output
+def get_output(links_pred, atoms_batch, atoms_polarity):
+    r"""
+    Parameters:
+        links_pred : atom_vocab_size, batch_size, max atoms in one type
+        atoms_batch : batch_size, max atoms in sentence
+        atoms_polarity : batch_size, max atoms in sentence
+    """
+    sentences_with_links = []
+    for s_idx in range(len(atoms_batch)):
+        atoms = atoms_batch[s_idx]
+        polarities = atoms_polarity[s_idx]
+# endregion
\ No newline at end of file
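The new get_output above is committed unfinished (the loop body stops after pulling out atoms and polarities). Purely as a sketch of a possible continuation, assuming, by analogy with draw_sentence_output further down, that links_pred[type_idx][s_idx][k] holds the index of the negative atom linked to the k-th positive atom of that type, and that re and atom_map_redux are in scope in this module:

        # hypothetical continuation, not the committed code
        sentence_links = {}
        for atom_type, type_idx in atom_map_redux.items():
            # count the positive atoms of this type in the sentence
            n_pos = sum(1 for i, a in enumerate(atoms)
                        if re.match(r"" + atom_type + r"(_{1}\w+)?\Z", a) and polarities[i])
            # keep only the predictions for atoms that actually occur
            sentence_links[atom_type] = links_pred[type_idx][s_idx][:n_pos]
        sentences_with_links.append(sentence_links)
    return sentences_with_links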
scp -r cdepourt@osirim-slurm.irit.fr:projets/deepgrailGPU1/deepgrail_RNN_with_linker/TensorBoard/ /home/cdepourt/Bureau/deepgrail_RNN_with_linker/TensorBoard
rsync -av -e ssh --exclude="__pycache__" --exclude="venv" --exclude=".git" --exclude=".idea" -r /home/cdepourt/Bureau/deepgrail_RNN_with_linker cdepourt@osirim-slurm.irit.fr:projets/deepgrail2
import numpy as np
import torch
from Configuration import Configuration
from Linker import *
from Linker.atom_map import atom_map_redux
from Linker.utils_linker import get_atoms_batch, get_GOAL, get_atoms_links_batch, get_axiom_links
from Supertagger.SuperTagger.SuperTagger import SuperTagger
from utils import read_csv_pgbar
import re
torch.cuda.empty_cache()
batch_size = int(Configuration.modelTrainingConfig['batch_size'])
nb_sentences = batch_size * 800
epochs = int(Configuration.modelTrainingConfig['epoch'])
file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
atoms_batch, atoms_polarity_batch, num_batch = get_GOAL(290, 875, df_axiom_links)
truth_links_batch = get_axiom_links(324, atoms_polarity_batch, df_axiom_links["Y"])
print("max idx for link", torch.max(truth_links_batch))
neg_idx = [[[i for i, x in enumerate(sentence) if
             bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))
             and not atoms_polarity_batch[s_idx][i]]
            for s_idx, sentence in enumerate(atoms_batch)]
           for atom_type in list(atom_map_redux.keys())]
max_atoms_in_one_type = 0
for atoms_type_batch in neg_idx:
    for sentence in atoms_type_batch:
        if len(sentence) > max_atoms_in_one_type:
            max_atoms_in_one_type = len(sentence)
print("max atoms of one type in one sentence", max_atoms_in_one_type)
atoms_links_batch = get_atoms_links_batch(df_axiom_links["Y"])
max_atoms_in_links = 0
sentence_max = ""
for sentence in atoms_links_batch:
    if len(sentence) > max_atoms_in_links:
        max_atoms_in_links = len(sentence)
        sentence_max = sentence
print("max atoms in links", max_atoms_in_links)
max_atoms_in_sentence = 0
sentence_max = ""
for sentence in atoms_batch:
    if len(sentence) > max_atoms_in_sentence:
        max_atoms_in_sentence = len(sentence)
        sentence_max = sentence
print("max atoms in categories", max_atoms_in_sentence)
supertagger = SuperTagger()
supertagger.load_weights("models/flaubert_super_98_V2_50e.pt")
sentences_batch = df_axiom_links["X"].str.strip().tolist()
sentences_tokens, sentences_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
max_len_sentence = 0
sentence_max = ""
for sentence in sentences_tokens:
    if len(sentence) > max_len_sentence:
        max_len_sentence = len(sentence)
        sentence_max = sentence
print(" max len sentence", max_len_sentence)
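Side note: each of the max-length scans in this script could also use the built-in max; equivalent one-liners (a suggestion, not part of the commit):

# max_atoms_in_one_type = max((len(s) for batch in neg_idx for s in batch), default=0)
# max_atoms_in_links = max((len(s) for s in atoms_links_batch), default=0)
# max_len_sentence = max((len(s) for s in sentences_tokens), default=0)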
import re
import graphviz
import numpy as np
import regex
from Linker.atom_map import atom_map, atom_map_redux
regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
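A quick illustration (sketch): the (?R) recursion supported by the third-party regex module lets regex_categories split a category one level deep, keeping nested arguments whole, which is exactly how recursive_linking below consumes it.

m = regex.match(regex_categories, "dr(0,dl(0,n,np),n)")
parts = [g for g in m.groups() if g is not None]
print(parts[0], parts[1])  # the two top-level arguments: dl(0,n,np) and n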
def recursive_linking(links, dot, category, parent_id, word_idx, depth,
                      polarity, compt_plus, compt_neg):
    # Is the category an atomic type (an entry of atom_map)?
    res = [(category == atom_type) for atom_type in atom_map.keys()]
    if True in res:
        # Atomic category: flip the polarity and place an atom node.
        polarity = not polarity
        if polarity:
            # Positive atoms are numbered in reading order.
            atoms_idx = compt_plus[category]
            compt_plus[category] += 1
        else:
            # Negative atoms take the index of the positive atom they are linked to,
            # read from the predicted links row of this atom type.
            idx_neg = compt_neg[category]
            compt_neg[category] += 1
            atoms_idx = np.where(links[atom_map_redux[category]] == idx_neg)[0][0]
        atom_id = category + "_" + str(polarity) + "_" + str(atoms_idx)
        dot.node(atom_id, category + " " + str("+" if polarity else "-"))
        dot.edge(parent_id, atom_id)
    else:
        # Complex category: add an internal node and recurse on its arguments.
        category_id = category + "_" + str(word_idx) + "_" + str(depth)
        dot.node(category_id, category + " " + str("+" if polarity else "-"))
        dot.edge(parent_id, category_id)
        parent_id = category_id
        # dr
        if category.startswith("dr"):
            categories_inside = regex.match(regex_categories, category).groups()
            categories_inside = [cat for cat in categories_inside if cat is not None]
            categories_inside = [categories_inside[0], categories_inside[1]]
            polarities_inside = [polarity, not polarity]
        # dl / p
        elif category.startswith("dl") or category.startswith("p"):
            categories_inside = regex.match(regex_categories, category).groups()
            categories_inside = [cat for cat in categories_inside if cat is not None]
            categories_inside = [categories_inside[0], categories_inside[1]]
            polarities_inside = [not polarity, polarity]
        # box / dia
        elif category.startswith("box") or category.startswith("dia"):
            categories_inside = regex.match(regex_categories, category).groups()
            categories_inside = [cat for cat in categories_inside if cat is not None]
            categories_inside = [categories_inside[0]]
            polarities_inside = [polarity]
        else:
            categories_inside = []
            polarities_inside = []

        for cat_id in range(len(categories_inside)):
            recursive_linking(links, dot, categories_inside[cat_id], parent_id, word_idx,
                              depth + 1, polarities_inside[cat_id], compt_plus, compt_neg)
def draw_sentence_output(sentence, categories, links):
    dot = graphviz.Graph('linking', comment='Axiom linking')
    dot.graph_attr['rankdir'] = 'BT'
    dot.attr('edge', tailport='n')
    dot.attr('edge', headport='s')
    compt_plus = {'cl_r': 0, 'pp': 0, 'n': 0, 'np': 0, 'cl_y': 0, 'txt': 0, 's': 0}
    compt_neg = {'cl_r': 0, 'pp': 0, 'n': 0, 'np': 0, 'cl_y': 0, 'txt': 0, 's': 0}
    for word_idx in range(len(sentence)):
        word = sentence[word_idx]
        word_id = word + "_" + str(word_idx)
        dot.node(word_id, word)
        category = categories[word_idx]
        polarity = True
        parent_id = word_id
        recursive_linking(links, dot, category, parent_id, word_idx, 0, polarity, compt_plus, compt_neg)

    dot.attr('edge', color='red')
    dot.attr('edge', style='dashed')
    dot.attr('edge', tailport='n')
    dot.attr('edge', headport='n')
    for atom_type in list(atom_map_redux.keys()):
        for id in range(compt_plus[atom_type]):
            atom_plus = atom_type + "_" + str(True) + "_" + str(id)
            atom_moins = atom_type + "_" + str(False) + "_" + str(id)
            dot.edge(atom_plus, atom_moins, constraint="false")

    dot.render(format="svg", view=True)
    return dot.source
sentence = ["Le", "chat","est","noir","bleu"]
categories = ["dr(0,s,n)", "dl(0,s,n)","dr(0,dl(0,n,np),n)", "dl(0,np,n)","n"]
links = np.array([[0,0,0,0],[0,0,0,0],[1,0,2,0],[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]])
draw_sentence_output(sentence, categories, links)
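Reading the toy links array: each row is one atom type (assuming the rows follow atom_map_redux's key order, i.e. the same order as the compt_plus dictionary above), and column k holds the index of the negative atom linked to the k-th positive atom of that type; unused columns are just zeros.

# under that assumption, row 2 is the atom type 'n'
for pos_k, neg_k in enumerate(links[2]):
    print(f"n_+{pos_k} -> n_-{neg_k}")  # only the entries for atoms that actually occur are meaningful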