Commit b54804c0 authored by Caroline DE POURTALES

starting train

parent 8dc363bd
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
@@ -9,7 +9,7 @@ from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
 from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
 from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
+from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch, mesure_accuracy
 from SuperTagger.Linker.AttentionLayer import FFN, AttentionLayer
 from SuperTagger.utils import pad_sequence
@@ -98,11 +98,36 @@ class Linker(Module):
         return link_weights

-    def predict_axiom_links(self):
+    def predict_axiom_links(self, b_sents_tokenized, b_sents_mask):
         return None

-    def eval_batch(self):
-        return None
+    def eval_batch(self, batch, cross_entropy_loss):
+        b_sents_tokenized = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
+        b_sents_mask = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
+        b_category = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
+
+        logits_axiom_links_pred = self.predict_axiom_links(b_sents_tokenized, b_sents_mask)
+        # Softmax and argmax over the class dimension
+        axiom_links_pred = torch.argmax(torch.nn.functional.softmax(logits_axiom_links_pred, dim=2), dim=2)
+
+        accuracy = mesure_accuracy(b_category, axiom_links_pred)
+        # cross-entropy needs the raw logits with classes on dim 1,
+        # not the argmaxed integer predictions
+        loss = float(cross_entropy_loss(logits_axiom_links_pred.permute(0, 2, 1), b_category))
+
+        return accuracy, loss

-    def eval_epoch(self):
-        return None
+    def eval_epoch(self, dataloader, cross_entropy_loss):
+        r"""Average the evaluation over all the batches.
+
+        Args:
+            dataloader: contains all the batches, each holding the tokenized sentences, their masks and the true symbols
+        """
+        accuracy_average = 0
+        loss_average = 0
+        compt = 0
+        for step, batch in enumerate(dataloader):
+            compt += 1
+            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
+            accuracy_average += accuracy
+            loss_average += loss
+
+        return accuracy_average / compt, loss_average / compt
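
For orientation, here is a minimal sketch of how eval_epoch could be driven once predict_axiom_links is actually implemented (it still returns None in this commit). The Linker() construction, tensor shapes and vocabulary sizes below are placeholder assumptions for illustration, not part of the commit:

import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical batch data: token ids, attention masks and gold categories.
# Shapes and vocabulary sizes are placeholder assumptions.
sents = torch.randint(0, 100, (8, 20))
masks = torch.ones(8, 20, dtype=torch.long)
categories = torch.randint(0, 10, (8, 20))

dataloader = DataLoader(TensorDataset(sents, masks, categories), batch_size=4)

linker = Linker()  # constructor arguments omitted here
accuracy, loss = linker.eval_epoch(dataloader, CrossEntropyLoss())
print(accuracy, loss)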
@@ -92,3 +92,10 @@ def find_pos_neg_idexes(batch_symbols):
         list_symbols.append(cut_category_in_symbols(category))
         list_batch.append(list_symbols)
     return list_batch
+
+
+def mesure_accuracy(b_category, axiom_links_pred):
+    # Convert b_category into
+    return 0
\ No newline at end of file
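
mesure_accuracy is still a stub that returns 0. A plausible padding-aware version, assuming b_category and axiom_links_pred are integer tensors of the same shape and that a padding_id marks positions to ignore (both assumptions, not the project's final design), could look like this:

import torch

def mesure_accuracy_sketch(b_category, axiom_links_pred, padding_id=0):
    # Fraction of non-padded positions where the predicted link
    # matches the gold category (hypothetical sketch).
    mask = b_category != padding_id
    correct = (axiom_links_pred == b_category) & mask
    return correct.sum().item() / max(mask.sum().item(), 1)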
#!/bin/sh
#SBATCH --job-name=N-tensorboard
#SBATCH --partition=RTX6000Node
#SBATCH --gres=gpu:1
#SBATCH --mem=32000
#SBATCH --gres-flags=enforce-binding
#SBATCH --error="error_rtx1.err"
#SBATCH --output="out_rtx1.out"
module purge
module load singularity/3.0.3
srun singularity exec /logiciels/containerCollections/CUDA11/pytorch-NGC-21-03-py3.sif python "train.py"
\ No newline at end of file
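
For reference, a job script like the one above would typically be submitted from a SLURM login node with `sbatch <script>.sh`; the partition name, container path and the `train.py` entry point are taken verbatim from the commit.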
from collections import Counter
import statistics

import numpy as np

from Configuration import Configuration
from SuperTagger.Symbol.SymbolTokenizer import SymbolTokenizer
from SuperTagger.Symbol.symbol_map import symbol_map
from SuperTagger.utils import read_csv_pgbar

file_path = 'Datasets/m2_dataset.csv'
max_symbols_in_sentence = int(Configuration.modelDecoderConfig['max_symbols_in_sentence'])
max_len_sentence = int(Configuration.modelDecoderConfig['max_len_sentence'])

df = read_csv_pgbar(file_path)

# Flatten the per-sentence symbol lists and count how often each symbol occurs
all_symbols = [item for sublist in list(df['sub_tree']) for item in sublist]
counter = Counter(all_symbols)
print(counter)

number_total_symbols = len(all_symbols)
print(number_total_symbols)

# Most frequent symbol and its count
most_common_symbol, max_number_in_one_symbol = counter.most_common(1)[0]
print(most_common_symbol)
print(max_number_in_one_symbol)

# Sixth most frequent symbol, as a mid-range reference point
middle_common_symbol, middle_number_in_one_symbol = counter.most_common(6)[5]
print(middle_common_symbol)
print(middle_number_in_one_symbol)

mean = statistics.mean(counter.values())
print(mean)


def get_weight(count_symbol_x, count_symbol_threshold):
    # Rarer symbols (count below the threshold) get larger weights,
    # frequent ones a weight closer to 1; at exactly the mean count
    # x = 1, so the weight is 1 + ln(2)**2 ≈ 1.48.
    x = count_symbol_threshold / count_symbol_x
    return 1 + np.log(x + 1) ** 2


# Weight of each symbol, relative to the mean frequency
dic_symbols_weights = {}
for (key, value) in counter.items():
    dic_symbols_weights[key] = np.round(get_weight(value, mean), 4)
print(dic_symbols_weights)

# Order the weights following symbol_map, skipping the special tokens
list_ordered = []
for symbol in symbol_map.keys():
    if symbol != '[START]' and symbol != '[PAD]':
        list_ordered.append(dic_symbols_weights[symbol])
print(list_ordered)
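
A natural consumer for list_ordered (an assumption on my part; this script only prints it) is the weight argument of PyTorch's cross-entropy, so that rare symbols contribute more to the loss:

import torch
from torch.nn import CrossEntropyLoss

# Hypothetical use: the ordering of list_ordered must match the class
# indices the model emits (here, symbol_map minus [START] and [PAD]).
class_weights = torch.tensor(list_ordered, dtype=torch.float)
weighted_cross_entropy = CrossEntropyLoss(weight=class_weights)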