diff --git a/Configuration/Configuration.py b/Configuration/Configuration.py
index 12a4b5f29755681ec929e73496d9bd3921fc9fc4..9a120c779e7430b1d6f4d190eebdc3ed526c85a2 100644
--- a/Configuration/Configuration.py
+++ b/Configuration/Configuration.py
@@ -10,7 +10,6 @@ config.read(path_config_file)
 
 # region Get section
 version = config["VERSION"]
-
 datasetConfig = config["DATASET_PARAMS"]
 modelEncoderConfig = config["MODEL_ENCODER"]
 modelLinkerConfig = config["MODEL_LINKER"]
diff --git a/Configuration/config.ini b/Configuration/config.ini
index 4de3f49932ff21d1da07523d9d3f63a5ccfc0651..61314f05b83580c954c3650d8147bb61c01ddb16 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -2,29 +2,29 @@
 transformers = 4.16.2
 
 [DATASET_PARAMS]
-symbols_vocab_size=26
-atom_vocab_size=18
-max_len_sentence=290
-max_atoms_in_sentence=875
-max_atoms_in_one_type=324
+symbols_vocab_size = 26
+atom_vocab_size = 18
+max_len_sentence = 83
+max_atoms_in_sentence = 238
+max_atoms_in_one_type = 102
 
 [MODEL_ENCODER]
 dim_encoder = 768
 
 [MODEL_LINKER]
-nhead=8
-dim_emb_atom = 512
-dim_feedforward_transformer = 768
-num_layers=3
-dim_cat_inter=768
-dim_cat_out=512
-dim_intermediate_FFN=256
-dim_pre_sinkhorn_transfo=32
-dropout=0.1
-sinkhorn_iters=5
+nhead = 8
+dim_emb_atom = 256
+dim_feedforward_transformer = 512
+num_layers = 3
+dim_cat_out = 512
+dim_intermediate_ffn = 256
+dim_pre_sinkhorn_transfo = 32
+dropout = 0.1
+sinkhorn_iters = 5
 
 [MODEL_TRAINING]
-batch_size=32
-epoch=30
-seed_val=42
-learning_rate=2e-3
\ No newline at end of file
+batch_size = 32
+epoch = 30
+seed_val = 42
+learning_rate = 2e-3
+
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 498a828cdfb07d15af96084f589ab311a18deed2..2f1084428e01b600e8c7baa9df58a045bdabac35 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -1,9 +1,7 @@
+import datetime
 import math
 import os
-import re
 import sys
-import datetime
-
 import time
 
 import torch
@@ -17,17 +15,16 @@ from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 from Configuration import Configuration
+from Linker.AtomTokenizer import AtomTokenizer
 from Linker.PositionalEncoding import PositionalEncoding
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from Linker.AtomTokenizer import AtomTokenizer
 from Linker.atom_map import atom_map, atom_map_redux
 from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx
-from Supertagger import SuperTagger
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx, get_atoms_batch, \
+    find_pos_neg_idexes, get_num_atoms_batch
+from SuperTagger import SuperTagger
 from utils import pad_sequence
 
-import torch
-
 
 def format_time(elapsed):
     '''
@@ -73,7 +70,6 @@ class Linker(Module):
         self.num_layers = int(Configuration.modelLinkerConfig['num_layers'])
         # torch cat
         dropout = float(Configuration.modelLinkerConfig['dropout'])
-        self.dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_inter'])
        self.dim_cat_out = int(Configuration.modelLinkerConfig['dim_cat_out'])
         dim_intermediate_FFN = int(Configuration.modelLinkerConfig['dim_intermediate_FFN'])
         dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo'])
@@ -87,7 +83,7 @@ class Linker(Module):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         # endregion
 
-        # Supertagger for categories
+        # SuperTagger for categories
         supertagger = SuperTagger()
         supertagger.load_weights(supertagger_path_model)
         self.Supertagger = supertagger
@@ -145,11 +141,14 @@ class Linker(Module):
         sentences_batch = df_axiom_links["X"].str.strip().tolist()
         sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
-        atoms_batch, atoms_polarity_batch, num_atoms_per_word = get_GOAL(self.max_len_sentence, self.max_atoms_in_sentence, df_axiom_links)
+        atoms_batch, polarities, num_atoms_per_word = get_GOAL(self.max_len_sentence, df_axiom_links)
+        atoms_polarity_batch = pad_sequence(
+            [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+            max_len=self.max_atoms_in_sentence, padding_value=0)
         atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
-        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
-        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
+        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
+        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
 
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch, df_axiom_links["Y"])
 
@@ -203,8 +202,8 @@ class Linker(Module):
         atoms_encoding = self.linker_encoder(atoms_sentences_encoding)
 
         # linking per atom type
-        batch_size, atom_vocan_size, _ = batch_pos_idx.shape
-        link_weights = torch.zeros(atom_vocan_size, batch_size, self.max_atoms_in_one_type // 2,
+        batch_size, atom_vocab_size, _ = batch_pos_idx.shape
+        link_weights = torch.zeros(atom_vocab_size, batch_size, self.max_atoms_in_one_type // 2,
                                    self.max_atoms_in_one_type // 2, device=self.device)
         for atom_type in list(atom_map_redux.keys()):
             pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type)
@@ -252,7 +251,7 @@ class Linker(Module):
 
             if checkpoint:
                 self.__checkpoint_save(
-                    path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
+                    path=os.path.join("Output", 'linker' + datetime.datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
 
             if tensorboard:
                 writer.add_scalars(f'Accuracy', {
@@ -319,7 +318,6 @@ class Linker(Module):
             accuracy_train += mesure_accuracy(batch_true_links, pred_axiom_links, self.max_atoms_in_one_type)
 
             self.scheduler.step()
-            print("learning rate ", self.scheduler.get_last_lr())
 
             # Measure how long this epoch took.
             training_time = format_time(time.time() - t0)
@@ -370,19 +368,51 @@ class Linker(Module):
 
         return loss_average / len(dataloader), accuracy_average / len(dataloader)
 
+    def predict(self, sentence, categories):
+        r""" Predict the links from a sentence and its categories
+
+        Args :
+            sentence : list of words composing the sentence
+            categories : list of categories (tags) of each word
+        """
+        self.eval()
+        with torch.no_grad():
+            sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors([sentence])
+            sentences_tokens = sentences_tokens.to(self.device)
+            nb_sentence, len_sentence = sentences_tokens.shape
+            sentences_mask = sentences_mask.to(self.device)
+
+            atoms = get_atoms_batch([categories])
+            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms).to(self.device)
+
+            polarities = find_pos_neg_idexes([categories])
+            polarities = pad_sequence(
+                [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+                max_len=self.max_atoms_in_sentence, padding_value=0).to(self.device)
+
+            num_atoms_per_word = get_num_atoms_batch([categories], len_sentence).to(self.device)
+
+            pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type).to(self.device)
+            neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type).to(self.device)
+
+            output = self.Supertagger.forward(sentences_tokens, sentences_mask)
+
+            logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embeding'])
+            axiom_links_pred = torch.argmax(logits_predictions, dim=3)
+
+        return axiom_links_pred
+
     def load_weights(self, model_file):
         print("#" * 15)
         try:
             params = torch.load(model_file, map_location=self.device)
-            args = params['args']
-            self.max_atoms_in_sentence = args['max_atoms_in_sentence']
-            self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
             self.atom_encoder.load_state_dict(params['atom_encoder'])
             self.position_encoder.load_state_dict(params['position_encoder'])
             self.transformer.load_state_dict(params['transformer'])
             self.linker_encoder.load_state_dict(params['linker_encoder'])
             self.pos_transformation.load_state_dict(params['pos_transformation'])
             self.neg_transformation.load_state_dict(params['neg_transformation'])
+            self.cross_entropy_loss.load_state_dict(params['cross_entropy_loss'])
             self.optimizer.load_state_dict(params['optimizer'])
             print("\n The loading checkpoint was successful ! \n")
         except Exception as e:
@@ -399,10 +429,11 @@ class Linker(Module):
         torch.save({
             'atom_encoder': self.atom_encoder.state_dict(),
             'position_encoder': self.position_encoder,
-            'transformer': self.transformer,
+            'transformer': self.transformer.state_dict(),
             'linker_encoder': self.linker_encoder.state_dict(),
             'pos_transformation': self.pos_transformation.state_dict(),
             'neg_transformation': self.neg_transformation.state_dict(),
+            'cross_entropy_loss': self.cross_entropy_loss,
             'optimizer': self.optimizer,
         }, path)
         self.to(self.device)
diff --git a/Linker/README.md b/Linker/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d6903fa8d28b1c64f2e798008332b7eacecc8786
--- /dev/null
+++ b/Linker/README.md
@@ -0,0 +1,50 @@
+# DeepGrail Linker
+
+This repository contains a Python implementation of a Neural Proof Net using TLGbank data.
+
+This code was designed to work with the [DeepGrail Tagger](https://gitlab.irit.fr/pnria/global-helper/deepgrail_tagger).
+In this repository we only use the word embeddings from the tagger and the tags from the dataset; the next step is to use the tagger's predictions for the linking step.
+
+## Usage
+
+### Installation
+Python 3.9.10 **(warning: do not use Python 3.10+)**
+Clone the project locally.
+
+### Libraries installation
+
+In a clean Python venv, run `pip install -r requirements.txt`.
+
+### Dataset format
+
+The sentences should be in a column "X", the links (atoms suffixed with '_x') in a column "Y" and the categories in a column "Z".
+For the links, each atom_x is paired with the one and only other atom_x in the sentence.
+
+## Training
+
+Launch train.py; inside it you can point to another dataset file and another tagging model.
+
+During training, if you use `checkpoint=True`, the model is automatically saved after each epoch in a folder named
+Training_XX-XX_XX-XX. Use `tensorboard=True` to write logs to the same folder (run `tensorboard --logdir=logs` to view them).
+
+## Predicting
+
+To predict on your own data, you need to load a model (saved with this code).
+
+```
+df = read_csv_pgbar(file_path, 20)
+texts = df['X'].tolist()
+categories = df['Z'].tolist()
+
+linker = Linker(tagging_model)
+linker.load_weights("your/linker/path")
+
+links = linker.predict(texts[7], categories[7])
+print(links)
+```
+
+The file `postprocessing.py` lets you draw the predicted links (keep sentences short, otherwise the drawing gets confusing).
+
+## Authors
+
+[de Pourtales Caroline](https://www.linkedin.com/in/caroline-de-pourtales/), [Rabault Julien](https://www.linkedin.com/in/julienrabault)
\ No newline at end of file
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index 3b38a77f92e1ffd85e8b76af626d5cbedb843295..15b37f39d0abfda2f09b50e26de7b744ba0796b9 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -318,7 +318,7 @@ print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', '
 
 # region get atoms and polarities with GOAL
 
-def get_GOAL(max_len_sentence, max_atoms_in_sentence, df_axiom_links):
+def get_GOAL(max_len_sentence, df_axiom_links):
     categories_batch = df_axiom_links["Z"]
     categories_with_goal = df_axiom_links["Y"]
     polarities = find_pos_neg_idexes(categories_batch)
@@ -334,8 +334,7 @@ def get_GOAL(max_len_sentence, df_axiom_links):
         polarities[s_idx] = polarities_goal + polarities[s_idx]  # + False
         num_atoms_batch[s_idx][0] += len(atoms)  # +1
 
-    return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-                                     max_len=max_atoms_in_sentence, padding_value=0), num_atoms_batch
+    return atoms_batch, polarities, num_atoms_batch
 
 
 df_axiom_links = pd.DataFrame({"Z": [['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
@@ -343,14 +342,14 @@ df_axiom_links = pd.DataFrame({"Z": [['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)',
                                "Y": [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)',
                                       'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]})
 
-print(" test for get GOAL ", get_GOAL(10, 30, df_axiom_links))
+print(" test for get GOAL ", get_GOAL(10, df_axiom_links))
 
 # endregion
 
 
 # region get idx for pos and neg
 
-def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
+def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
                                               bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
                                                            atoms_batch[s_idx][i])) and atoms_polarity_batch[s_idx][i]])
@@ -361,7 +360,7 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_at
     return torch.stack(pos_idx).permute(1, 0, 2)
 
 
-def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
+def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
                                               bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
                                                             atoms_batch[s_idx][i])) and not atoms_polarity_batch[s_idx][i]])
@@ -378,25 +377,6 @@ print(" test for cut into pos neg on ['dr(0,s,np)', 's']",
                                                 [[True, True, False, False, True, False, False, False, False, False,
-                                                   False, False]]), 10, 50))
-
-# endregion
-
-
-# region style Output
-
-def get_output(links_pred, atoms_batch, atoms_polarity):
-    r"""
-    Parameters:
-        links_pred : atom_vocab_size, batch_size, max atoms in one type
-        atoms_batch : batch_size, max atoms in sentence
-        atoms_polarity : batch_size, max atoms in sentence
-    """
-    sentences_with_links = []
-    for s_idx in range(len(atoms_batch)) :
-        atoms = atoms_batch[s_idx]
-        polarities = atoms_polarity[s_idx]
-
-
+                                                   False, False]]), 10))
 # endregion
\ No newline at end of file
diff --git a/README.md b/README.md
index 5994f1455440e7055fec3c5dd2f7e9baaa7e0cd5..15a86160d5bcbc79b6612d9712e0bef9dbc85ce1 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,7 @@
 # DeepGrail
 
-## Usage
+This repository contains a Python implementation of a Neural Proof Net using TLGbank data.
 
-### Installation
-Python 3.9.10 **(Warning don't use Python 3.10**+**)**
-
-Clone the project locally. In a clean python venv do `pip install -r requirements.txt`
-
-## How To use
-
-TODO ...
-
-tensorboard --logdir=logs
 
+## Authors
+[de Pourtales Caroline](https://www.linkedin.com/in/caroline-de-pourtales/), [Rabault Julien](https://www.linkedin.com/in/julienrabault)
\ No newline at end of file
diff --git a/find_config.py b/find_config.py
index 58d95bdf679280d800596787f11056df42fc5a72..53725288c4f3c1c2b29c7fe2f8d473fb0d0a4f4a 100644
--- a/find_config.py
+++ b/find_config.py
@@ -1,63 +1,61 @@
-import numpy as np
-import torch
-from Configuration import Configuration
-from Linker import *
-from Linker.atom_map import atom_map_redux
-from Linker.utils_linker import get_atoms_batch, get_GOAL, get_atoms_links_batch, get_axiom_links
-from Supertagger.SuperTagger.SuperTagger import SuperTagger
-from utils import read_csv_pgbar
+import configparser
 import re
+
+import torch
 
-torch.cuda.empty_cache()
-batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 800
-epochs = int(Configuration.modelTrainingConfig['epoch'])
-file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
-df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
-
-atoms_batch, atoms_polarity_batch, num_batch = get_GOAL(290, 875, df_axiom_links)
-
-truth_links_batch = get_axiom_links(324, atoms_polarity_batch, df_axiom_links["Y"])
-print("max idx for link", torch.max(truth_links_batch))
-
-neg_idx = [[[i for i, x in enumerate(sentence) if
-             bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))
-             and not atoms_polarity_batch[s_idx][i]]
-            for s_idx, sentence in enumerate(atoms_batch)]
-           for atom_type in list(atom_map_redux.keys())]
-max_atoms_in_on_type = 0
-for atoms_type_batch in neg_idx:
-    for sentence in atoms_type_batch:
-        if len(sentence) > max_atoms_in_on_type:
-            max_atoms_in_on_type = len(sentence)
-print("max atoms of one type in one sentence", max_atoms_in_on_type)
-
-atoms_links_batch = get_atoms_links_batch(df_axiom_links["Y"])
-max_atoms_in_links = 0
-sentence_max = ""
-for sentence in atoms_links_batch:
-    if len(sentence) > max_atoms_in_links:
-        max_atoms_in_links = len(sentence)
-        sentence_max = sentence
-print("max atoms in links", max_atoms_in_links)
-
-max_atoms_in_sentence = 0
-sentence_max = ""
-for sentence in atoms_batch:
-    if len(sentence) > max_atoms_in_sentence:
-        max_atoms_in_sentence = len(sentence)
-        sentence_max = sentence
-print("max atoms in categories", max_atoms_in_sentence)
-
-supertagger = SuperTagger()
-supertagger.load_weights("models/flaubert_super_98_V2_50e.pt")
-sentences_batch = df_axiom_links["X"].str.strip().tolist()
-sentences_tokens, sentences_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
-max_len_sentence = 0
-sentence_max = ""
-for sentence in sentences_tokens:
-    if len(sentence) > max_len_sentence:
-        max_len_sentence = len(sentence)
-        sentence_max = sentence
-print(" max len sentence", max_len_sentence)
+from Linker.atom_map import atom_map_redux
+from Linker.utils_linker import get_GOAL, get_atoms_links_batch, get_atoms_batch
+from SuperTagger.SuperTagger.SuperTagger import SuperTagger
+from utils import read_csv_pgbar, pad_sequence
+
+
+def configurate(dataset, model_tagger, nb_sentences=1000000000):
+    print("#" * 20)
+    print("#" * 20)
+    print("Configuration with dataset\n")
+    config = configparser.ConfigParser()
+    config.read('Configuration/config.ini')
+
+    file_path_axiom_links = dataset
+    df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
+
+    supertagger = SuperTagger()
+    supertagger.load_weights(model_tagger)
+    sentences_batch = df_axiom_links["X"].str.strip().tolist()
+    sentences_tokens, sentences_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
+    max_len_sentence = 0
+    for sentence in sentences_tokens:
+        if len(sentence) > max_len_sentence:
+            max_len_sentence = len(sentence)
+    print("Configure parameter max len sentence to ", max_len_sentence)
+    config.set('DATASET_PARAMS', 'max_len_sentence', str(max_len_sentence))
+
+    atoms_batch, polarities, num_batch = get_GOAL(max_len_sentence, df_axiom_links)
+    max_atoms_in_sentence = 0
+    for sentence in atoms_batch:
+        if len(sentence) > max_atoms_in_sentence:
+            max_atoms_in_sentence = len(sentence)
+    print("Configure parameter max atoms in categories to", max_atoms_in_sentence)
+    config.set('DATASET_PARAMS', 'max_atoms_in_sentence', str(max_atoms_in_sentence))
+
+    atoms_polarity_batch = pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+                                        max_len=max_atoms_in_sentence, padding_value=0)
+    pos_idx = [[torch.as_tensor([i for i, x in enumerate(sentence) if
+                                 bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))
+                                 and atoms_polarity_batch[s_idx][i]])
+                for s_idx, sentence in enumerate(atoms_batch)]
+               for atom_type in list(atom_map_redux.keys())]
+    max_atoms_in_on_type = 0
+    for atoms_type_batch in pos_idx:
+        for sentence in atoms_type_batch:
+            length = sentence.size(0)
+            if length > max_atoms_in_on_type:
+                max_atoms_in_on_type = length
+    print("Configure parameter max atoms of one type in one sentence to", max_atoms_in_on_type)
+    config.set('DATASET_PARAMS', 'max_atoms_in_one_type', str(max_atoms_in_on_type * 2 + 2))
+
+    with open('Configuration/config.ini', 'w') as configfile:  # save
+        config.write(configfile)
+
+    print("#" * 20)
+    print("#" * 20)
\ No newline at end of file
diff --git a/postprocessing.py b/postprocessing.py
index 4dbf2007d546153e1025229c10c535282aa74339..d2d43f03a99f877de7391c9e62bd91aea2996e30 100644
--- a/postprocessing.py
+++ b/postprocessing.py
@@ -10,6 +10,19 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
 
 
 def recursive_linking(links, dot, category, parent_id, word_idx, depth, polarity, compt_plus, compt_neg):
+    r"""
+    Recursive linking between the atoms inside a category.
+    :param links: predicted links between the atoms of the sentence
+    :param dot: graphviz graph being built
+    :param category: category to decompose into atoms
+    :param parent_id: identifier of the parent node in the graph
+    :param word_idx: index of the current word in the sentence
+    :param depth: recursion depth inside the category
+    :param polarity: current polarity (True for positive, False for negative)
+    :param compt_plus: counters of positive atoms seen so far, per atom type
+    :param compt_neg: counters of negative atoms seen so far, per atom type
+    :return: None, the graph is built in place
+    """
     res = [(category == atom_type) for atom_type in atom_map.keys()]
     if True in res:
         polarity = not polarity
@@ -54,43 +67,53 @@ def recursive_linking(links, dot, category, parent_id, word_idx, depth,
             polarities_inside = []
 
         for cat_id in range(len(categories_inside)):
-            recursive_linking(links, dot, categories_inside[cat_id], parent_id, word_idx, depth + 1, polarities_inside[cat_id], compt_plus,
-                              compt_neg)
+            recursive_linking(links, dot, categories_inside[cat_id], parent_id, word_idx, depth + 1,
+                              polarities_inside[cat_id], compt_plus,
+                              compt_neg)
 
 
 def draw_sentence_output(sentence, categories, links):
+    r"""
+    Draw the prediction for a sentence, given its categories and predicted links.
+    :param sentence: list of words
+    :param categories: list of categories
+    :param links: predicted links
+    :return: dot source
+    """
     dot = graphviz.Graph('linking', comment='Axiom linking')
     dot.graph_attr['rankdir'] = 'BT'
-    dot.attr('edge', tailport='n')
-    dot.attr('edge', headport='s')
+    dot.graph_attr['splines'] = 'ortho'
+    dot.graph_attr['ordering'] = 'in'
 
     compt_plus = {'cl_r': 0, 'pp': 0, 'n': 0, 'np': 0, 'cl_y': 0, 'txt': 0, 's': 0}
     compt_neg = {'cl_r': 0, 'pp': 0, 'n': 0, 'np': 0, 'cl_y': 0, 'txt': 0, 's': 0}
+    last_word_id = ""
     for word_idx in range(len(sentence)):
         word = sentence[word_idx]
         word_id = word + "_" + str(word_idx)
        dot.node(word_id, word)
+        if word_idx > 0:
+            dot.edge(last_word_id, word_id, constraint="false", style="invis")
 
         category = categories[word_idx]
         polarity = True
         parent_id = word_id
         recursive_linking(links, dot, category, parent_id, word_idx, 0, polarity, compt_plus, compt_neg)
+        last_word_id = word_id
 
     dot.attr('edge', color='red')
     dot.attr('edge', style='dashed')
-    dot.attr('edge', tailport='n')
-    dot.attr('edge', headport='n')
     for atom_type in list(atom_map_redux.keys()):
         for id in range(compt_plus[atom_type]):
-            atom_plus = atom_type+"_"+str(True)+"_"+str(id)
-            atom_moins = atom_type+"_"+str(False)+"_"+str(id)
+            atom_plus = atom_type + "_" + str(True) + "_" + str(id)
+            atom_moins = atom_type + "_" + str(False) + "_" + str(id)
             dot.edge(atom_plus, atom_moins, constraint="false")
 
     dot.render(format="svg", view=True)
     return dot.source
 
 
-sentence = ["Le", "chat","est","noir","bleu"]
-categories = ["dr(0,s,n)", "dl(0,s,n)","dr(0,dl(0,n,np),n)", "dl(0,np,n)","n"]
-links = np.array([[0,0,0,0],[0,0,0,0],[1,0,2,0],[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]])
-draw_sentence_output(sentence, categories, links)
\ No newline at end of file
+sentence = ["Le", "chat", "est", "noir", "bleu"]
+categories = ["dr(0,s,n)", "dl(0,s,n)", "dr(0,dl(0,n,np),n)", "dl(0,np,n)", "n"]
+links = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 2, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
+draw_sentence_output(sentence, categories, links)
diff --git a/train.py b/train.py
index fdf3936a593eaf3ceecc042167694b57caec06d2..0f1d17c6abd722b78f4060d1f1f9433314950aac 100644
--- a/train.py
+++ b/train.py
@@ -2,16 +2,21 @@ import torch
 from Configuration import Configuration
 from Linker import *
 from utils import read_csv_pgbar
+from find_config import configurate
 
 torch.cuda.empty_cache()
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 800
-epochs = int(Configuration.modelTrainingConfig['epoch'])
-
+nb_sentences = batch_size * 4
 file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
+model_tagger = "models/flaubert_super_98_V2_50e.pt"
+configurate(file_path_axiom_links, model_tagger, nb_sentences=nb_sentences)
+
+epochs = int(Configuration.modelTrainingConfig['epoch'])
 df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 
 print("Linker")
-linker = Linker("models/flaubert_super_98_V2_50e.pt")
+# Load the Linker with trained tagger
+linker = Linker(model_tagger)
 print("\nLinker Training\n")
-linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=False, tensorboard=True)
\ No newline at end of file
+linker.train_linker(df_axiom_links, validation_rate=0.05, epochs=1, batch_size=batch_size,
+                    checkpoint=True, tensorboard=True)
diff --git a/utils.py b/utils.py
index 0433510b2838731d38fd2e42e16a4a7b94ecf3b8..c4fae14e45ecebee5077044687bd9db9a5280936 100644
--- a/utils.py
+++ b/utils.py
@@ -6,6 +6,14 @@ from tqdm import tqdm
 
 
 def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
+    r"""
+    Pad a list of sequences to the same length, in preparation for a TensorDataset.
+    :param sequences: data to pad
+    :param batch_first: boolean indicating whether the batch dimension comes first
+    :param padding_value: the value used for padding
+    :param max_len: the maximum length
+    :return: the padded sequences
+    """
     max_size = sequences[0].size()
     trailing_dims = max_size[1:]
     if batch_first:
@@ -26,7 +34,13 @@
 
 
 def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500):
-    print("\n" + "#" * 20)
+    r"""
+    Read a csv dataset into a DataFrame, displaying a progress bar.
+    :param csv_path: path to the csv file
+    :param nrows: maximum number of rows to read
+    :param chunksize: number of rows read per chunk
+    :return: a pandas DataFrame with the loaded data
+    """
     print("Loading csv...")
 
     rows = sum(1 for _ in open(csv_path, 'r', encoding="utf8")) - 1  # minus the header
@@ -42,7 +56,6 @@ def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500):
             bar.update(len(chunk))
 
     df = pd.concat((f for f in chunk_list), axis=0)
-    print("#" * 20)
 
     return df
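A minimal usage sketch for the `pad_sequence` helper documented above (illustrative only: the example tensors and `max_len=5` are made up; the signature comes from `utils.py` in this patch):

```python
import torch

from utils import pad_sequence

# Three "sentences" with 2, 3 and 1 atoms respectively.
sequences = [torch.tensor([1, 2]), torch.tensor([3, 4, 5]), torch.tensor([6])]

# Each sequence is padded with padding_value up to max_len, batch dimension first.
padded = pad_sequence(sequences, batch_first=True, padding_value=0, max_len=5)
print(padded.shape)  # torch.Size([3, 5])
```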