Commit f996b207 authored by Caroline DE POURTALES

cleaning

parent 74dafadb
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
https://gitlab-ci-token:glpat-AZdpzmAPDFCSK8nPZxCw@gitlab.irit.fr/pnria/global-helper/deepgrail-rnn.git
\ No newline at end of file
import re

# Strip the numeric occurrence suffix from an indexed atom name,
# e.g. "cl_r_1" -> "cl_r".
print(re.match(r'([a-zA-Z_]+)_\d+', "cl_r_1").group(1))
\ No newline at end of file
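
For context, a minimal sketch of what this pattern does on a few made-up occurrence names (the inputs below are illustrative, not taken from the dataset):

import re

ATOM_SUFFIX = re.compile(r'([a-zA-Z_]+)_\d+')

# Indexed atom occurrences map back to their base atom name.
for occurrence in ["cl_r_1", "np_12", "txt_0"]:
    print(occurrence, "->", ATOM_SUFFIX.match(occurrence).group(1))
# cl_r_1 -> cl_r
# np_12 -> np
# txt_0 -> txt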
import torch
import torch.nn.functional as F

from SuperTagger.Linker.Linker import Linker
from SuperTagger.utils import read_csv_pgbar

file_path = "Datasets/m2_dataset_V2.csv"
file_path_validation = 'Datasets/aa1_links_dataset_links.csv'  # not used in this quick check

# Load a small sample (10 rows) of the axiom-links dataset.
df_axiom_links = read_csv_pgbar(file_path, 10)

# One sentence, given as the sequence of its predicted categories.
data = [['dr(0,np,n)', "n", 'dr(0,dl(0,np,np),n)', 'n', 'dr(0,dl(0,n,n),n)', 'n', 'dl(0,np,txt)']]

linker = Linker()
result = linker(data, [])
print(result.shape)
print(result)

# Most probable link per atom: softmax over the candidate axis, then argmax.
axiom_links_pred = torch.argmax(F.softmax(result, dim=3), dim=3)
print(axiom_links_pred.shape)
print(axiom_links_pred)
\ No newline at end of file
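
A minimal sketch of how the argmax decoding above behaves on a made-up tensor; the shape (batch, atom_types, premises, conclusions) is an assumption for illustration, not the Linker's documented output:

import torch
import torch.nn.functional as F

# Hypothetical link logits: 1 sentence, 2 atom types, 3 premise atoms,
# 3 candidate conclusion atoms (shape assumed for illustration only).
logits = torch.randn(1, 2, 3, 3)

# Softmax over the last axis turns each premise row into a distribution
# over conclusions; argmax picks the most probable conclusion per premise.
links = torch.argmax(F.softmax(logits, dim=3), dim=3)
print(links.shape)  # torch.Size([1, 2, 3])
print(links)        # chosen conclusion index per (sentence, atom type, premise)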
@@ -31,7 +31,7 @@ atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
 # region ParamsTraining
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 2
+nb_sentences = batch_size * 10
 epochs = int(Configuration.modelTrainingConfig['epoch'])
 seed_val = int(Configuration.modelTrainingConfig['seed_val'])
 learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
@@ -151,7 +151,6 @@ def run_epochs(epochs):
 # Run the linker on the categories predictions
 logits_predictions = linker(batch_atoms, batch_polarity, [])
-print(logits_predictions.permute(1, 0, 2, 3).shape)
 linker_loss = cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links)
 # Perform a backward pass to calculate the gradients.
...
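
A hedged sketch of why a permute precedes the loss call: assuming the linker emits logits with the batch axis second, permute(1, 0, 2, 3) moves batch first so each sentence contributes one (premises × conclusions) score matrix per atom type. The shapes and the use of nn.CrossEntropyLoss below are illustrative only; the project's own loss (SinkhornLoss) may differ.

import torch
import torch.nn as nn

# Assumed shapes, for illustration: suppose the linker returns
# (atom_types, batch, premises, conclusions) and the loss wants batch first.
atom_types, batch, n = 2, 4, 3
logits = torch.randn(atom_types, batch, n, n)
permuted = logits.permute(1, 0, 2, 3)  # -> (batch, atom_types, premises, conclusions)

# One gold conclusion index per (sentence, atom type, premise) slot.
true_links = torch.randint(0, n, (batch, atom_types, n))

# Each premise row is a classification over the n candidate conclusions.
loss = nn.CrossEntropyLoss()(permuted.reshape(-1, n), true_links.reshape(-1))
print(loss.item())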