From 7aa67b281e1351f1bbb80a6c1c579a3ceadfb8c8 Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <cdepourt@montana.irit.fr> Date: Wed, 13 Jul 2022 15:56:44 +0200 Subject: [PATCH] update --- Linker/Linker.py | 9 ++++++++- README.md | 33 +++++++++++++++++++++++++++++++-- postprocessing.py | 2 +- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/Linker/Linker.py b/Linker/Linker.py index 5b86fe8..76596f0 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -380,6 +380,9 @@ class Linker(Module): Args : sentence : list of words composing the sentence categories : list of categories (tags) of each word + + Return : + links : links prediction """ self.eval() with torch.no_grad(): @@ -413,6 +416,10 @@ class Linker(Module): Args : sentence : list of words composing the sentence + + Return : + categories : the supertags predicted + links : links prediction """ self.eval() with torch.no_grad(): @@ -440,7 +447,7 @@ class Linker(Module): logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embeding']) axiom_links_pred = torch.argmax(logits_predictions, dim=3) - return axiom_links_pred + return categories, axiom_links_pred def load_weights(self, model_file): print("#" * 15) diff --git a/README.md b/README.md index 97a9213..ce70e16 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,35 @@ Clone the project locally. Run the init.sh script or install the Tagger project under SuperTagger name. And upload the tagger.pt in the directory 'models'. (You may need to modify 'model_tagger' in train.py.) +### Structure + +The structure should look like this : +``` +. +. +├── Configuration # Configuration +│ ├── Configuration.py # Contains the function to execute for config +│ └── config.ini # contains parameters +├── find_config.py # auto-configurate datasets parameters (max length sentence etc) according to the dataset given +├── requirements.txt # librairies needed +├── Datasets # TLGbank data with links +├── SuperTagger # The Supertagger directory (that you need to install) +│ ├── Datasets # TLGbank data +│ ├── SuperTagger # Implementation of BertForTokenClassification +│ │ ├── SuperTagger.py # Main class +│ │ └── Tagging_bert_model.py # Bert model +│ ├── predict.py # Example of prediction for supertagger +│ └── train.py # Example of train for supertagger +├── Linker # The Linker directory +│ ├── ... +│ └── Linker.py # Linker class containing the neural network +├── models +│ └── supertagger.pt # the pt file contaning the pretrained supertagger (you need to install it) +├── Output # Directory where your linker models will be savec if checkpoint=True in train +├── TensorBoard # Directory where the stats will be savec if tensorboard=True in train +└── train.py # Example of train +``` + ### Dataset format The sentences should be in a column "X", the links with '_x' postfix should be in a column "Y" and the categories in a column "Z". @@ -24,8 +53,8 @@ For the links each atom_x goes with the one and only other atom_x in the sentenc Launch train.py, if you look at it you can give another dataset file and another tagging model. -In train, if you use `checkpoint=True`, the model is automatically saved in a folder: Training_XX-XX_XX-XX. It saves -after each epoch. Use `tensorboard=True` for log in same folder. (`tensorboard --logdir=logs` for see logs) +In train, if you use `checkpoint=True`, the model is automatically saved in a folder: Output/Training_XX-XX_XX-XX. It saves +after each epoch. Use `tensorboard=True` for log saving in folder TensorBoard. (`tensorboard --logdir=logs` for see logs) ## Predicting diff --git a/postprocessing.py b/postprocessing.py index d2d43f0..ff7fbfb 100644 --- a/postprocessing.py +++ b/postprocessing.py @@ -77,7 +77,7 @@ def draw_sentence_output(sentence, categories, links): Drawing the prediction of a sentence when given categories and links predictions :param sentence: list of words :param categories: list of categories - :param links: links predicted + :param links: links predicted, output of predict_with/without_categories :return: dot source """ dot = graphviz.Graph('linking', comment='Axiom linking') -- GitLab