From b54804c04a42442b54fb90851d2c0908e77b0674 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Thu, 5 May 2022 14:25:37 +0200
Subject: [PATCH] starting train

---
 SuperTagger/Linker/Linker.py | 37 +++++++++++++++++++++-----
 SuperTagger/Linker/utils.py  |  7 +++++
 bash_GPU.sh                  | 13 ---------
 weighting.py                 | 51 ------------------------------------
 4 files changed, 38 insertions(+), 70 deletions(-)
 delete mode 100644 bash_GPU.sh
 delete mode 100644 weighting.py

diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
index 1f43920..208e8d4 100644
--- a/SuperTagger/Linker/Linker.py
+++ b/SuperTagger/Linker/Linker.py
@@ -9,7 +9,7 @@ from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
 from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
 from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
+from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch, mesure_accuracy
 from SuperTagger.Linker.AttentionLayer import FFN, AttentionLayer
 from SuperTagger.utils import pad_sequence
 
@@ -98,11 +98,36 @@ class Linker(Module):
 
         return link_weights
 
-    def predict_axiom_links(self):
+    def predict_axiom_links(self, b_sents_tokenized, b_sents_mask):
         return None
 
-    def eval_batch(self):
-        return None
+    def eval_batch(self, batch, cross_entropy_loss):
+        b_sents_tokenized = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
+        b_sents_mask = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
+        b_category = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
 
-    def eval_epoch(self):
-        return None
+        logits_axiom_links_pred = self.predict_axiom_links(b_sents_tokenized, b_sents_mask)
+        # Softmax and argmax
+        axiom_links_pred = torch.argmax(torch.nn.functional.softmax(logits_axiom_links_pred, dim=2), dim=2)
+
+        accuracy = mesure_accuracy(b_category, axiom_links_pred)
+        loss = float(cross_entropy_loss(logits_axiom_links_pred, b_category))
+
+        return accuracy, loss
+
+    def eval_epoch(self, dataloader, cross_entropy_loss):
+        r"""Average the evaluation over all the batches.
+
+        Args:
+            dataloader: contains all the batches, which contain the tokenized sentences, their masks and the true symbols
+        """
+        accuracy_average = 0
+        loss_average = 0
+        compt = 0
+        for step, batch in enumerate(dataloader):
+            compt += 1
+            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
+            accuracy_average += accuracy
+            loss_average += loss
+
+        return accuracy_average / compt, loss_average / compt
diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
index f2e72e1..ce569de 100644
--- a/SuperTagger/Linker/utils.py
+++ b/SuperTagger/Linker/utils.py
@@ -92,3 +92,10 @@ def find_pos_neg_idexes(batch_symbols):
             list_symbols.append(cut_category_in_symbols(category))
         list_batch.append(list_symbols)
     return list_batch
+
+
+def mesure_accuracy(b_category, axiom_links_pred):
+
+    # Convert b_category into
+
+    return 0
\ No newline at end of file
diff --git a/bash_GPU.sh b/bash_GPU.sh
deleted file mode 100644
index 665f769..0000000
--- a/bash_GPU.sh
+++ /dev/null
@@ -1,13 +0,0 @@
-#!/bin/sh
-#SBATCH --job-name=N-tensorboard
-#SBATCH --partition=RTX6000Node
-#SBATCH --gres=gpu:1
-#SBATCH --mem=32000
-#SBATCH --gres-flags=enforce-binding
-#SBATCH --error="error_rtx1.err"
-#SBATCH --output="out_rtx1.out"
-
-module purge
-module load singularity/3.0.3
-
-srun singularity exec /logiciels/containerCollections/CUDA11/pytorch-NGC-21-03-py3.sif python "train.py"
\ No newline at end of file
diff --git a/weighting.py b/weighting.py
deleted file mode 100644
index 0e6ee36..0000000
--- a/weighting.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from Configuration import Configuration
-from SuperTagger.Symbol.SymbolTokenizer import SymbolTokenizer
-from SuperTagger.utils import read_csv_pgbar
-from SuperTagger.Symbol.symbol_map import symbol_map
-
-from collections import Counter
-import numpy as np
-
-import statistics
-
-file_path = 'Datasets/m2_dataset.csv'
-max_symbols_in_sentence = int(Configuration.modelDecoderConfig['max_symbols_in_sentence'])
-max_len_sentence = int(Configuration.modelDecoderConfig['max_len_sentence'])
-
-df = read_csv_pgbar(file_path)
-all_symbols = [item for sublist in list(df['sub_tree']) for item in sublist]
-
-counter = Counter(all_symbols)
-print(counter)
-
-number_total_symbols = len(all_symbols)
-print(number_total_symbols)
-
-most_common_symbol, max_number_in_one_symbol = counter.most_common(1)[0]
-print(most_common_symbol)
-print(max_number_in_one_symbol)
-
-middle_common_symbol, middle_number_in_one_symbol = counter.most_common(6)[5]
-print(middle_common_symbol)
-print(middle_number_in_one_symbol)
-
-mean = statistics.mean(counter.values())
-print(mean)
-
-
-def get_weight(count_symbol_x, count_symbol_threashold):
-    x = count_symbol_threashold / count_symbol_x
-    return 1 + np.log(x + 1)**2
-
-
-dic_symbols_weights = {}
-for (key, value) in counter.items():
-    dic_symbols_weights[key] = np.round(get_weight(value, mean), 4)
-print(dic_symbols_weights)
-
-list_ordered = []
-for symbol in symbol_map.keys():
-    if symbol != '[START]' and symbol != '[PAD]':
-        list_ordered.append(dic_symbols_weights[symbol])
-
-print(list_ordered)
--
GitLab
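
A possible completion of the mesure_accuracy stub added above in SuperTagger/Linker/utils.py, given here only as a minimal sketch: it assumes b_category and axiom_links_pred are integer tensors of the same shape (batch, sequence length) and that a padding index, hypothetically 0, marks positions to ignore; the unfinished "# Convert b_category into" comment in the patch suggests an extra conversion step that is not reproduced here.

import torch

def mesure_accuracy(b_category, axiom_links_pred, padding_idx=0):
    # Sketch: fraction of non-padded positions where the predicted axiom link
    # matches the reference; padding_idx=0 is an assumed convention, and any
    # conversion of b_category into link indices is left out.
    mask = b_category != padding_idx
    correct = (axiom_links_pred == b_category) & mask
    total = int(mask.sum().item())
    return int(correct.sum().item()) / total if total > 0 else 0.0

With a helper of this shape, eval_epoch(dataloader, cross_entropy_loss) returns the accuracy and loss averaged over the batches it iterates.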