diff --git a/Configuration/config.ini b/Configuration/config.ini index 0cf354cabac8903a035728343ccaba2acdae9ed9..4de3f49932ff21d1da07523d9d3f63a5ccfc0651 100644 --- a/Configuration/config.ini +++ b/Configuration/config.ini @@ -13,8 +13,8 @@ dim_encoder = 768 [MODEL_LINKER] nhead=8 -dim_emb_atom = 256 -dim_feedforward_transformer = 512 +dim_emb_atom = 512 +dim_feedforward_transformer = 768 num_layers=3 dim_cat_inter=768 dim_cat_out=512 diff --git a/Linker/Linker.py b/Linker/Linker.py index 370b25b0282c01b2c800c2dee08eff00f9b52696..498a828cdfb07d15af96084f589ab311a18deed2 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -128,7 +128,7 @@ class Linker(Module): self.cross_entropy_loss = SinkhornLoss() self.optimizer = AdamW(self.parameters(), lr=learning_rate) - self.scheduler = StepLR(self.optimizer, step_size=3, gamma=0.5) + self.scheduler = StepLR(self.optimizer, step_size=2, gamma=0.5) self.to(self.device) @@ -257,8 +257,6 @@ class Linker(Module): if tensorboard: writer.add_scalars(f'Accuracy', { 'Train': avg_accuracy_train}, epoch_i) - writer.add_scalars(f'Learning rate', { - 'learning_rate': self.scheduler.get_last_lr()}, epoch_i) writer.add_scalars(f'Loss', { 'Train': avg_train_loss}, epoch_i) if validation_rate > 0.0: diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py index 92b41d589d6d8f8b250f3f6882d1c2f92cce046f..3b38a77f92e1ffd85e8b76af626d5cbedb843295 100644 --- a/Linker/utils_linker.py +++ b/Linker/utils_linker.py @@ -352,8 +352,7 @@ print(" test for get GOAL ", get_GOAL(10, 30, df_axiom_links)) def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence): pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if - bool( - re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and + bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and atoms_polarity_batch[s_idx][i]]) for s_idx, sentence in enumerate(atoms_batch)], max_len=max_atoms_in_one_type // 2, padding_value=-1) @@ -364,8 +363,7 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_at def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence): pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if - bool( - re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and + bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and not atoms_polarity_batch[s_idx][i]]) for s_idx, sentence in enumerate(atoms_batch)], max_len=max_atoms_in_one_type // 2, padding_value=-1) @@ -383,3 +381,22 @@ print(" test for cut into pos neg on ['dr(0,s,np)', 's']", False, False]]), 10, 50)) # endregion + + +# region style Output + +def get_output(links_pred, atoms_batch, atoms_polarity): + r""" + Parameters: + links_pred : atom_vocab_size, batch_size, max atoms in one type + atoms_batch : batch_size, max atoms in sentence + atoms_polarity : batch_size, max atoms in sentence + """ + sentences_with_links = [] + for s_idx in range(len(atoms_batch)) : + atoms = atoms_batch[s_idx] + polarities = atoms_polarity[s_idx] + + + +# endregion \ No newline at end of file diff --git a/command_line.txt b/command_line.txt deleted file mode 100644 index 31b0fc48391da24d8765df9b88505d761357daa3..0000000000000000000000000000000000000000 --- a/command_line.txt +++ /dev/null @@ -1,4 +0,0 @@ -scp -r cdepourt@osirim-slurm.irit.fr:projets/deepgrailGPU1/deepgrail_RNN_with_linker/TensorBoard/ /home/cdepourt/Bureau/deepgrail_RNN_with_linker/TensorBoard - -rsync -av -e ssh --exclude="__pycache__" --exclude="venv" --exclude=".git" --exclude=".idea" -r /home/cdepourt/Bureau/deepgrail_RNN_with_linker cdepourt@osirim-slurm.irit.fr:projets/deepgrail2 - diff --git a/find_config.py b/find_config.py new file mode 100644 index 0000000000000000000000000000000000000000..58d95bdf679280d800596787f11056df42fc5a72 --- /dev/null +++ b/find_config.py @@ -0,0 +1,63 @@ +import numpy as np +import torch +from Configuration import Configuration +from Linker import * +from Linker.atom_map import atom_map_redux +from Linker.utils_linker import get_atoms_batch, get_GOAL, get_atoms_links_batch, get_axiom_links +from Supertagger.SuperTagger.SuperTagger import SuperTagger +from utils import read_csv_pgbar +import re + + +torch.cuda.empty_cache() +batch_size = int(Configuration.modelTrainingConfig['batch_size']) +nb_sentences = batch_size * 800 +epochs = int(Configuration.modelTrainingConfig['epoch']) +file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv' +df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences) + +atoms_batch, atoms_polarity_batch, num_batch = get_GOAL(290, 875, df_axiom_links) + +truth_links_batch = get_axiom_links(324, atoms_polarity_batch, df_axiom_links["Y"]) +print("max idx for link", torch.max(truth_links_batch)) + +neg_idx = [[[i for i, x in enumerate(sentence) if + bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) + and not atoms_polarity_batch[s_idx][i]] + for s_idx, sentence in enumerate(atoms_batch)] + for atom_type in list(atom_map_redux.keys())] +max_atoms_in_on_type = 0 +for atoms_type_batch in neg_idx: + for sentence in atoms_type_batch: + if len(sentence) > max_atoms_in_on_type: + max_atoms_in_on_type = len(sentence) +print("max atoms of one type in one sentence", max_atoms_in_on_type) + +atoms_links_batch = get_atoms_links_batch(df_axiom_links["Y"]) +max_atoms_in_links = 0 +sentence_max = "" +for sentence in atoms_links_batch: + if len(sentence) > max_atoms_in_links: + max_atoms_in_links = len(sentence) + sentence_max = sentence +print("max atoms in links", max_atoms_in_links) + +max_atoms_in_sentence = 0 +sentence_max = "" +for sentence in atoms_batch: + if len(sentence) > max_atoms_in_sentence: + max_atoms_in_sentence = len(sentence) + sentence_max = sentence +print("max atoms in categories", max_atoms_in_sentence) + +supertagger = SuperTagger() +supertagger.load_weights("models/flaubert_super_98_V2_50e.pt") +sentences_batch = df_axiom_links["X"].str.strip().tolist() +sentences_tokens, sentences_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch) +max_len_sentence = 0 +sentence_max = "" +for sentence in sentences_tokens: + if len(sentence) > max_len_sentence: + max_len_sentence = len(sentence) + sentence_max = sentence +print(" max len sentence", max_len_sentence) diff --git a/postprocessing.py b/postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..4dbf2007d546153e1025229c10c535282aa74339 --- /dev/null +++ b/postprocessing.py @@ -0,0 +1,96 @@ +import re + +import graphviz +import numpy as np +import regex +from Linker.atom_map import atom_map, atom_map_redux + +regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)' + + +def recursive_linking(links, dot, category, parent_id, word_idx, depth, + polarity, compt_plus, compt_neg): + res = [(category == atom_type) for atom_type in atom_map.keys()] + if True in res: + polarity = not polarity + if polarity: + atoms_idx = compt_plus[category] + compt_plus[category] += 1 + else: + idx_neg = compt_neg[category] + compt_neg[category] += 1 + atoms_idx = np.where(links[atom_map_redux[category]] == idx_neg)[0][0] + atom_id = category + "_" + str(polarity) + "_" + str(atoms_idx) + dot.node(atom_id, category + " " + str("+" if polarity else "-")) + dot.edge(parent_id, atom_id) + else: + category_id = category + "_" + str(word_idx) + "_" + str(depth) + dot.node(category_id, category + " " + str("+" if polarity else "-")) + dot.edge(parent_id, category_id) + parent_id = category_id + + if category.startswith("dr"): + categories_inside = regex.match(regex_categories, category).groups() + categories_inside = [cat for cat in categories_inside if cat is not None] + categories_inside = [categories_inside[0], categories_inside[1]] + polarities_inside = [polarity, not polarity] + + # dl / p + elif category.startswith("dl") or category.startswith("p"): + categories_inside = regex.match(regex_categories, category).groups() + categories_inside = [cat for cat in categories_inside if cat is not None] + categories_inside = [categories_inside[0], categories_inside[1]] + polarities_inside = [not polarity, polarity] + + # box / dia + elif category.startswith("box") or category.startswith("dia"): + categories_inside = regex.match(regex_categories, category).groups() + categories_inside = [cat for cat in categories_inside if cat is not None] + categories_inside = [categories_inside[0]] + polarities_inside = [polarity] + + else: + categories_inside = [] + polarities_inside = [] + + for cat_id in range(len(categories_inside)): + recursive_linking(links, dot, categories_inside[cat_id], parent_id, word_idx, depth + 1, polarities_inside[cat_id], compt_plus, + compt_neg) + + +def draw_sentence_output(sentence, categories, links): + dot = graphviz.Graph('linking', comment='Axiom linking') + dot.graph_attr['rankdir'] = 'BT' + dot.attr('edge', tailport='n') + dot.attr('edge', headport='s') + + compt_plus = {'cl_r': 0, 'pp': 0, 'n': 0, 'np': 0, 'cl_y': 0, 'txt': 0, 's': 0} + compt_neg = {'cl_r': 0, 'pp': 0, 'n': 0, 'np': 0, 'cl_y': 0, 'txt': 0, 's': 0} + for word_idx in range(len(sentence)): + word = sentence[word_idx] + word_id = word + "_" + str(word_idx) + dot.node(word_id, word) + + category = categories[word_idx] + polarity = True + parent_id = word_id + recursive_linking(links, dot, category, parent_id, word_idx, 0, polarity, compt_plus, compt_neg) + + dot.attr('edge', color='red') + dot.attr('edge', style='dashed') + dot.attr('edge', tailport='n') + dot.attr('edge', headport='n') + for atom_type in list(atom_map_redux.keys()): + for id in range(compt_plus[atom_type]): + atom_plus = atom_type+"_"+str(True)+"_"+str(id) + atom_moins = atom_type+"_"+str(False)+"_"+str(id) + dot.edge(atom_plus, atom_moins, constraint="false") + + dot.render(format="svg", view=True) + return dot.source + + +sentence = ["Le", "chat","est","noir","bleu"] +categories = ["dr(0,s,n)", "dl(0,s,n)","dr(0,dl(0,n,np),n)", "dl(0,np,n)","n"] +links = np.array([[0,0,0,0],[0,0,0,0],[1,0,2,0],[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]]) +draw_sentence_output(sentence, categories, links) \ No newline at end of file