Commit 06615a06 authored by Caroline DE POURTALES

works, 70% accuracy, needs to learn padding

parent b012fcf5
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
Showing with 227 additions and 128 deletions
......@@ -3,17 +3,22 @@ transformers = 4.16.2
[DATASET_PARAMS]
symbols_vocab_size=26
atom_vocab_size=18
max_len_sentence=290
max_atoms_in_sentence=1250
max_atoms_in_one_type=510
max_atoms_in_sentence=874
max_atoms_in_one_type=324
[MODEL_ENCODER]
dim_encoder = 768
[MODEL_LINKER]
dim_cat_out=768
dim_intermediate_FFN=256
dim_pre_sinkhorn_transfo=32
nhead=4
dim_emb_atom = 256
num_layers=2
dim_cat_inter=512
dim_cat_out=256
dim_intermediate_FFN=128
dim_pre_sinkhorn_transfo=64
dropout=0.1
sinkhorn_iters=5
......@@ -21,4 +26,4 @@ sinkhorn_iters=5
batch_size=32
epoch=25
seed_val=42
learning_rate=2e-3
learning_rate=2e-3
\ No newline at end of file
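For context, these sections are read through the project's Configuration module (used below as Configuration.modelLinkerConfig[...]). A minimal sketch of how such an INI file can be loaded with Python's configparser; the file path here is an assumption:

from configparser import ConfigParser

config = ConfigParser()
config.read("Configuration/config.ini")  # assumed location of the INI file above
modelLinkerConfig = config["MODEL_LINKER"]
nhead = int(modelLinkerConfig["nhead"])  # configparser returns strings, hence the int(...) casts in the code below
dim_emb_atom = int(modelLinkerConfig["dim_emb_atom"])
sinkhorn_iters = int(modelLinkerConfig["sinkhorn_iters"])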
import math
import os
import re
import sys
......@@ -7,7 +8,8 @@ import time
import torch
import torch.nn.functional as F
from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout
from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout, TransformerEncoderLayer, TransformerEncoder, \
Embedding
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import TensorDataset, random_split
......@@ -15,15 +17,17 @@ from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from Configuration import Configuration
from Linker.AtomTokenizer import AtomTokenizer
from Linker.PositionEncoding import PositionalEncoding
from Linker.PositionalEncoding import PositionalEncoding
from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from Linker.AtomTokenizer import AtomTokenizer
from Linker.atom_map import atom_map, atom_map_redux
from Linker.eval import mesure_accuracy, SinkhornLoss
from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx, get_num_atoms_batch
from Supertagger import SuperTagger
from utils import pad_sequence
import torch
def format_time(elapsed):
'''
......@@ -49,49 +53,73 @@ def output_create_dir():
return training_dir, writer
def generate_square_subsequent_mask(sz):
"""Generates an upper-triangular matrix of -inf, with zeros on diag."""
return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
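# Example (hypothetical call, not part of the training code): for sz = 3 the mask is
# generate_square_subsequent_mask(3)
# -> tensor([[0., -inf, -inf],
#            [0., 0., -inf],
#            [0., 0., 0.]])
# i.e. zeros on and below the diagonal and -inf strictly above it, so position i
# may only attend to positions <= i.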
class Linker(Module):
def __init__(self, supertagger_path_model):
super(Linker, self).__init__()
# region parameters
dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
# atom settings
atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
# Transformer
self.nhead = int(Configuration.modelLinkerConfig['nhead'])
self.dim_emb_atom = int(Configuration.modelLinkerConfig['dim_emb_atom'])
self.num_layers = int(Configuration.modelLinkerConfig['num_layers'])
# torch cat
self.dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_inter'])
self.dim_cat_out = int(Configuration.modelLinkerConfig['dim_cat_out'])
dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo'])
dim_intermediate_FFN = int(Configuration.modelLinkerConfig['dim_intermediate_FFN'])
dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo'])
# sinkhorn
self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
dropout = float(Configuration.modelLinkerConfig['dropout'])
# settings
self.batch_size = int(Configuration.modelTrainingConfig['batch_size'])
self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# endregion
# Supertagger for categories
supertagger = SuperTagger()
supertagger.load_weights(supertagger_path_model)
self.Supertagger = supertagger
self.Supertagger.model.to(self.device)
self.atom_map = atom_map
# Atoms embedding
self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
self.atom_map_redux = atom_map_redux
self.sub_atoms_type_list = list(atom_map_redux.keys())
self.padding_id = self.atom_map['[PAD]']
self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
self.inverse_map = self.atoms_tokenizer.inverse_atom_map
self.position_encoding = PositionalEncoding(dim_encoder, max_len=self.max_atoms_in_sentence)
dim_cat = dim_encoder * 2
self.linker_encoder = Linear(dim_cat, self.dim_cat_out, bias=False)
self.dropout = Dropout(dropout)
self.atom_encoder = Embedding(atom_vocab_size, self.dim_emb_atom, padding_idx=atom_map["[PAD]"])
self.atom_encoder.weight.data.uniform_(-0.1, 0.1)
self.position_encoder = PositionalEncoding(self.dim_emb_atom, 0.1, max_len=self.max_atoms_in_sentence)
encoder_layer = TransformerEncoderLayer(d_model=self.dim_emb_atom, nhead=self.nhead)
self.transformer = TransformerEncoder(encoder_layer, num_layers=self.num_layers)
# Concatenation with word embedding
dim_cat = dim_encoder + self.dim_emb_atom
self.linker_encoder = Sequential(
FFN(dim_cat, self.dim_cat_inter, 0.1, d_out=self.dim_cat_out),
LayerNorm(self.dim_cat_out, eps=1e-8)
)
# Division into positive and negative
self.pos_transformation = Sequential(
FFN(self.dim_cat_out, dim_intermediate_FFN, 0.1, d_out=dim_pre_sinkhorn_transfo),
LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-12)
LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8)
)
self.neg_transformation = Sequential(
FFN(self.dim_cat_out, dim_intermediate_FFN, 0.1, d_out=dim_pre_sinkhorn_transfo),
LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-12)
LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8)
)
# Learning
self.cross_entropy_loss = SinkhornLoss()
self.optimizer = AdamW(self.parameters(),
lr=learning_rate)
......@@ -113,20 +141,21 @@ class Linker(Module):
sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(
list(map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch)))
num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
pos_idx = get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
neg_idx = get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, self.sub_atoms_type_list, atoms_polarity_batch,
truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
df_axiom_links["Y"])
truth_links_batch = truth_links_batch.permute(1, 0, 2)
# Build the tensor dataset
dataset = TensorDataset(num_atoms_per_word, pos_idx, neg_idx, truth_links_batch, sentences_tokens,
sentences_mask)
dataset = TensorDataset(num_atoms_per_word, atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch,
sentences_tokens, sentences_mask)
if validation_rate > 0.0:
train_size = int(0.9 * len(dataset))
......@@ -141,38 +170,38 @@ class Linker(Module):
print("End preprocess Data")
return training_dataloader, validation_dataloader
def forward(self, batch_num_atoms_per_word, batch_pos_idx, batch_neg_idx, sents_embedding, cat_embedding):
def forward(self, batch_num_atoms_per_word, batch_atoms, src_mask, batch_pos_idx, batch_neg_idx, sents_embedding):
r"""
Args:
batch_num_atoms_per_word : (batch_size, len_sentence) number of atoms produced by each word's category
batch_atoms : (batch_size, max_atoms_in_sentence) tokenized atoms
src_mask : (max_atoms_in_sentence, max_atoms_in_sentence) attention mask for the atom transformer
batch_pos_idx : (batch_size, atom_vocab_size, max atoms in one cat) indices of the positive atoms of each type
batch_neg_idx : (batch_size, atom_vocab_size, max atoms in one cat) indices of the negative atoms of each type
sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
cat_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for cat embedding
Returns:
link_weights : (atom_vocab_size, batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat) log probabilities
"""
# repeat the word embedding for each atom in the word
# repeat the word embedding for each atom in the word (+1 atom per word for the [SEP] marker)
sents_embedding_repeat = pad_sequence(
[torch.repeat_interleave(input=sents_embedding[i], repeats=batch_num_atoms_per_word[i], dim=0)
for i in range(len(sents_embedding))], max_len=self.max_atoms_in_sentence, padding_value=0)
cat_embedding_repeat = pad_sequence(
[torch.repeat_interleave(input=cat_embedding[i], repeats=batch_num_atoms_per_word[i], dim=0)
for i in range(len(cat_embedding))], max_len=self.max_atoms_in_sentence, padding_value=0)
# positional encoding of atoms and cat embedding to form the atom embedding
position_encoding = self.position_encoding(cat_embedding_repeat)
atoms_embedding = self.atom_encoder(batch_atoms) * math.sqrt(self.dim_emb_atom)
atoms_embedding = self.position_encoder(atoms_embedding)
atoms_embedding = atoms_embedding.permute(1, 0, 2)
atoms_embedding = self.transformer(atoms_embedding, src_mask)
atoms_embedding = atoms_embedding.permute(1, 0, 2)
# concatenate word embeddings and atom embeddings
atoms_sentences_encoding = torch.cat([sents_embedding_repeat, position_encoding], dim=2)
atoms_sentences_encoding = torch.cat([sents_embedding_repeat, atoms_embedding], dim=2)
atoms_encoding = self.linker_encoder(atoms_sentences_encoding)
atoms_encoding = self.dropout(atoms_encoding)
# linking per atom type
batch_size, atom_vocab_size, _ = batch_pos_idx.shape
link_weights = torch.zeros(atom_vocab_size, batch_size, self.max_atoms_in_one_type // 2,
self.max_atoms_in_one_type // 2, device=self.device)
for atom_type in self.sub_atoms_type_list:
for atom_type in list(atom_map_redux.keys()):
pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type)
neg_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_neg_idx, atom_type)
......@@ -251,23 +280,25 @@ class Linker(Module):
# For each batch of training data...
with tqdm(training_dataloader, unit="batch") as tepoch:
src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
for batch in tepoch:
# Unpack this training batch from our dataloader
batch_num_atoms = batch[0].to(self.device)
batch_pos_idx = batch[1].to(self.device)
batch_neg_idx = batch[2].to(self.device)
batch_true_links = batch[3].to(self.device)
batch_sentences_tokens = batch[4].to(self.device)
batch_sentences_mask = batch[5].to(self.device)
batch_atoms_tok = batch[1].to(self.device)
batch_pos_idx = batch[2].to(self.device)
batch_neg_idx = batch[3].to(self.device)
batch_true_links = batch[4].to(self.device)
batch_sentences_tokens = batch[5].to(self.device)
batch_sentences_mask = batch[6].to(self.device)
self.optimizer.zero_grad()
# get the sentence embedding from the already-trained BERT
output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
# Run the linker on the categories predictions
logits_predictions = self(batch_num_atoms, batch_pos_idx, batch_neg_idx, output['word_embeding'],
output['last_hidden_state'])
# Run the Linker on the atoms
logits_predictions = self(batch_num_atoms, batch_atoms_tok, src_mask, batch_pos_idx, batch_neg_idx,
output['word_embeding'])
linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
# Perform a backward pass to calculate the gradients.
......@@ -294,21 +325,22 @@ class Linker(Module):
def eval_batch(self, batch):
batch_num_atoms = batch[0].to(self.device)
batch_pos_idx = batch[1].to(self.device)
batch_neg_idx = batch[2].to(self.device)
batch_true_links = batch[3].to(self.device)
batch_sentences_tokens = batch[4].to(self.device)
batch_sentences_mask = batch[5].to(self.device)
batch_atoms_tok = batch[1].to(self.device)
batch_pos_idx = batch[2].to(self.device)
batch_neg_idx = batch[3].to(self.device)
batch_true_links = batch[4].to(self.device)
batch_sentences_tokens = batch[5].to(self.device)
batch_sentences_mask = batch[6].to(self.device)
output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
logits_predictions = self(batch_num_atoms, batch_pos_idx, batch_neg_idx, output['word_embeding'],
output['last_hidden_state']) # atom_vocab, batch_size, max atoms in one type, max atoms in one type
src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
logits_predictions = self(batch_num_atoms, batch_atoms_tok, src_mask, batch_pos_idx, batch_neg_idx, output[
'word_embeding']) # atom_vocab, batch_size, max atoms in one type, max atoms in one type
axiom_links_pred = torch.argmax(logits_predictions, dim=3) # atom_vocab, batch_size, max atoms in one type
print('\n')
print("Tokens de la phrase : ", batch_sentences_tokens[1])
print("Polarités + des atoms de la phrase : ", batch_pos_idx[1][2][:50])
print("Polarités - des atoms de la phrase : ", batch_neg_idx[1][2][:50])
print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
print("Les prédictions : ", axiom_links_pred[2][1][:100])
print('\n')
......@@ -340,10 +372,11 @@ class Linker(Module):
try:
params = torch.load(model_file, map_location=self.device)
args = params['args']
self.atom_map = args['atom_map']
self.max_atoms_in_sentence = args['max_atoms_in_sentence']
self.atoms_tokenizer = AtomTokenizer(self.atom_map, self.max_atoms_in_sentence)
self.atoms_embedding.load_state_dict(params['atoms_embedding'])
self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
self.atom_encoder.load_state_dict(params['atom_encoder'])
self.position_encoder.load_state_dict(params['position_encoder'])
self.transformer.load_state_dict(params['transformer'])
self.linker_encoder.load_state_dict(params['linker_encoder'])
self.pos_transformation.load_state_dict(params['pos_transformation'])
self.neg_transformation.load_state_dict(params['neg_transformation'])
......@@ -361,8 +394,9 @@ class Linker(Module):
self.cpu()
torch.save({
'args': dict(atom_map=self.atom_map, max_atoms_in_sentence=self.max_atoms_in_sentence),
'atoms_embedding': self.atoms_embedding.state_dict(),
'atom_encoder': self.atom_encoder.state_dict(),
'position_encoder': self.position_encoder.state_dict(),
'transformer': self.transformer.state_dict(),
'linker_encoder': self.linker_encoder.state_dict(),
'pos_transformation': self.pos_transformation.state_dict(),
'neg_transformation': self.neg_transformation.state_dict(),
......@@ -384,4 +418,4 @@ class Linker(Module):
return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0, index=int(atom)).to(self.device)
if atom != -1 else torch.zeros(self.dim_cat_out, device=self.device)
for atom in sentence])
for i, sentence in enumerate(positional_ids[:, atom_map_redux[atom_type], :])])
for i, sentence in enumerate(positional_ids[:, self.atom_map_redux[atom_type], :])])
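For reference, the positive/negative encodings gathered here feed the link-weight computation, which is normalized with sinkhorn_fn_no_exp (imported from Linker.Sinkhorn above). Judging by its name it works in log space; a minimal sketch of log-domain Sinkhorn normalization under that assumption, not the project's actual implementation:

import torch

def log_sinkhorn_sketch(log_alpha: torch.Tensor, n_iters: int = 5) -> torch.Tensor:
    # Alternately normalize rows and columns in log space so that exp(log_alpha)
    # approaches a doubly stochastic matrix; its entries can then be read as
    # log probabilities of linking each positive atom to a negative atom.
    for _ in range(n_iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-1, keepdim=True)  # row normalization
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-2, keepdim=True)  # column normalization
    return log_alpha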
......@@ -5,7 +5,7 @@ import math
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
......@@ -19,7 +19,7 @@ class PositionalEncoding(nn.Module):
def forward(self, x):
"""
Args:
x: Tensor, shape [batch_size,seq_len, embedding_dim]
x: Tensor, shape [batch_size, seq_len, embedding_dim]
"""
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
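The buffer construction of PositionalEncoding is elided by the diff; it presumably follows the standard sinusoidal encoding. A sketch of that table under this assumption, shaped to match the batch-first indexing self.pe[:, :x.size(1)] used in forward:

import math
import torch

def sinusoidal_table(d_model: int, max_len: int) -> torch.Tensor:
    # Standard sinusoidal positional-encoding table, shape (1, max_len, d_model).
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(1, max_len, d_model)
    pe[0, :, 0::2] = torch.sin(position * div_term)
    pe[0, :, 1::2] = torch.cos(position * div_term)
    return pe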
from .Linker import Linker
from .atom_map import atom_map
from .AtomTokenizer import AtomTokenizer
from .PositionEncoding import PositionalEncoding
\ No newline at end of file
from .Sinkhorn import *
\ No newline at end of file
......@@ -15,7 +15,8 @@ atom_map = \
'txt': 13,
's': 14,
's_ppart': 15,
'[PAD]': 16
"[SEP]":16,
'[PAD]': 17
}
atom_map_redux = {
......
......@@ -17,8 +17,8 @@ def mesure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type):
batch_true_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat) true index of the negative atom linked to each positive atom
axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) predicted index of the negative atom linked to each positive atom
"""
padding = max_atoms_in_one_type // 2 -1
batch_true_links=batch_true_links.permute(1, 0, 2)
padding = max_atoms_in_one_type // 2 - 1
batch_true_links = batch_true_links.permute(1, 0, 2)
correct_links = torch.ones(axiom_links_pred.size())
correct_links[axiom_links_pred != batch_true_links] = 0
correct_links[batch_true_links == padding] = 1
......@@ -26,4 +26,5 @@ def mesure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type):
num_masked_atoms = len(batch_true_links[batch_true_links == padding])
# divide by the number of links
return (num_correct_links - num_masked_atoms)/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2] - num_masked_atoms)
return (num_correct_links - num_masked_atoms) / (
axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2] - num_masked_atoms)
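A toy check of the padding convention in mesure_accuracy (hypothetical values): padded slots hold the index max_atoms_in_one_type // 2 - 1, are forced to count as correct, and are then removed from both the numerator and the denominator.

import torch

max_atoms_in_one_type = 10
padding = max_atoms_in_one_type // 2 - 1                 # = 4
true_links = torch.tensor([[[0, 1, 4, 4]]])              # (atom_vocab=1, batch=1, 4 slots); last two slots are padding
pred_links = torch.tensor([[[0, 2, 3, 4]]])              # one real link correct, one wrong
true_links = true_links.permute(1, 0, 2)
correct_links = torch.ones(pred_links.size())
correct_links[pred_links != true_links] = 0
correct_links[true_links == padding] = 1                 # padded slots forced to "correct"
num_correct_links = int(correct_links.sum())             # 3
num_masked_atoms = int((true_links == padding).sum())    # 2
accuracy = (num_correct_links - num_masked_atoms) / (pred_links.numel() - num_masked_atoms)  # (3 - 2) / (4 - 2) = 0.5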
......@@ -3,6 +3,7 @@ import regex
import torch
from torch.nn import Sequential, Linear, Dropout, GELU
from torch.nn import Module
from Linker.atom_map import atom_map, atom_map_redux
from utils import pad_sequence
......@@ -28,34 +29,34 @@ regex_categories_axiom_links = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)
regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
################################ List of atoms with _i ########################################
def get_axiom_links(max_atoms_in_one_type, sub_atoms_type_list, atoms_polarity, batch_axiom_links):
# region get true axiom links
def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
r"""
Args:
max_atoms_in_one_type : configuration
sub_atoms_type_list : list of atom type to match
atoms_polarity : (batch_size, max_atoms_in_sentence)
batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
Returns:
batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains, for each positive atom, the index of the negative atom it links to
"""
atoms_batch = get_atoms_links_batch(batch_axiom_links)
atoms_batch = list(map(lambda sentence: sentence.split(" "), atoms_batch))
linking_plus_to_minus_all_types = []
for atom_type in sub_atoms_type_list:
for atom_type in list(atom_map_redux.keys()):
# filter atoms_batch to keep only this atom type, then use the indices to filter atoms_polarity
l_polarity_plus = [[x for i, x in enumerate(atoms_batch[s_idx]) if atoms_polarity[s_idx, i]
and bool(re.search(atom_type + "_", atoms_batch[s_idx][i]))] for s_idx in
and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx in
range(len(atoms_batch))]
l_polarity_minus = [[x for i, x in enumerate(atoms_batch[s_idx]) if not atoms_polarity[s_idx, i]
and bool(re.search(atom_type + "_", atoms_batch[s_idx][i]))] for s_idx in
and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx in
range(len(atoms_batch))]
linking_plus_to_minus = pad_sequence(
[torch.as_tensor(
[l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else max_atoms_in_one_type // 2 -1 for
i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2,
padding_value=max_atoms_in_one_type // 2 -1)
[l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else max_atoms_in_one_type // 2 - 1
for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2,
padding_value=max_atoms_in_one_type // 2 - 1)
linking_plus_to_minus_all_types.append(linking_plus_to_minus)
......@@ -74,15 +75,13 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
if category.startswith("GOAL:"):
word, cat = category.split(':')
return category_to_atoms_axiom_links(cat, categories_to_atoms)
elif category == "let":
return []
elif True in res:
return [category]
return " " + category
else:
category_cut = regex.match(regex_categories_axiom_links, category).groups()
category_cut = [cat for cat in category_cut if cat is not None]
for cat in category_cut:
categories_to_atoms += category_to_atoms_axiom_links(cat, [])
categories_to_atoms += category_to_atoms_axiom_links(cat, "")
return categories_to_atoms
......@@ -95,14 +94,26 @@ def get_atoms_links_batch(category_batch):
"""
batch = []
for sentence in category_batch:
categories_to_atoms = []
categories_to_atoms = ""
for category in sentence:
categories_to_atoms += category_to_atoms_axiom_links(category, [])
if category != "let" and not category.startswith("GOAL:"):
categories_to_atoms += category_to_atoms_axiom_links(category, "")
categories_to_atoms += " [SEP]"
categories_to_atoms = categories_to_atoms.lstrip()
elif category.startswith("GOAL:"):
categories_to_atoms += category_to_atoms_axiom_links(category, "")
categories_to_atoms = categories_to_atoms.lstrip()
batch.append(categories_to_atoms)
return batch
################################ List of atoms ########################################
print("test to create links ",
get_axiom_links(20, torch.stack([torch.as_tensor([False, True, False, False, False, True, False, True, False, False, True, False, False, False, True, False, False, True, False, True, False, False, True, False, False, False, True])]),
[['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]))
# endregion
# region get atoms in sentence
def category_to_atoms(category, categories_to_atoms):
r"""
......@@ -116,15 +127,13 @@ def category_to_atoms(category, categories_to_atoms):
if category.startswith("GOAL:"):
word, cat = category.split(':')
return category_to_atoms(cat, categories_to_atoms)
elif category == "let":
return []
elif True in res:
return [category]
return " " + category
else:
category_cut = regex.match(regex_categories, category).groups()
category_cut = [cat for cat in category_cut if cat is not None]
for cat in category_cut:
categories_to_atoms += category_to_atoms(cat, [])
categories_to_atoms += category_to_atoms(cat, "")
return categories_to_atoms
......@@ -137,14 +146,22 @@ def get_atoms_batch(category_batch):
"""
batch = []
for sentence in category_batch:
categories_to_atoms = []
categories_to_atoms = ""
for category in sentence:
categories_to_atoms += category_to_atoms(category, [])
if category != "let":
categories_to_atoms += category_to_atoms(category, "")
categories_to_atoms += " [SEP]"
categories_to_atoms = categories_to_atoms.lstrip()
batch.append(categories_to_atoms)
return batch
################################ List of atoms ########################################
print(" test for get atoms in categories on ['dr(0,s,np)', 'let']", get_atoms_batch([["dr(0,s,np)", "let"]]))
# endregion
# region calculate num atoms per category
def category_to_num_atoms(category, categories_to_atoms):
r"""
......@@ -182,12 +199,22 @@ def get_num_atoms_batch(category_batch, max_len_sentence):
for sentence in category_batch:
num_atoms_sentence = []
for category in sentence:
num_atoms_sentence.append(category_to_num_atoms(category, 0))
num_atoms_in_word = category_to_num_atoms(category, 0)
# add 1 because each word's atom sequence ends with a [SEP]
if category != "let":
num_atoms_in_word += 1
num_atoms_sentence.append(num_atoms_in_word)
batch.append(torch.as_tensor(num_atoms_sentence))
return pad_sequence(batch, max_len=max_len_sentence, padding_value=0)
################################ Polarity ###############################################
print(" test for get number of atoms in categories on ['dr(0,s,np)', 'let']",
get_num_atoms_batch([["dr(0,s,np)", "let"]], 10))
# endregion
# region get polarity
def category_to_atoms_polarity(category, polarity):
r"""
......@@ -207,8 +234,6 @@ def category_to_atoms_polarity(category, polarity):
category_to_polarity.append(True)
else:
category_to_polarity += category_to_atoms_polarity(cat, True)
elif category == "let":
pass
# the word has an atomic category
elif True in res:
category_to_polarity.append(not polarity)
......@@ -270,58 +295,91 @@ def find_pos_neg_idexes(atoms_batch):
for sentence in atoms_batch:
list_atoms = []
for category in sentence:
for at in category_to_atoms_polarity(category, True):
list_atoms.append(at)
if category == "let":
pass
else:
for at in category_to_atoms_polarity(category, True):
list_atoms.append(at)
list_atoms.append(False)
list_batch.append(list_atoms)
return list_batch
################################ GOAL ###############################################
print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']",
find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']]))
# endregion
# region get atoms and polarities with GOAL
def get_GOAL(max_atoms_in_sentence, categories_batch):
polarities = find_pos_neg_idexes(categories_batch)
atoms_batch = get_atoms_batch(categories_batch)
atoms_batch_for_polarities = list(
map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch))
for s_idx in range(len(atoms_batch)):
for atom_type in list(atom_map.keys()):
list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
and atoms_batch[s_idx][i] == atom_type]
list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
and atoms_batch[s_idx][i] == atom_type]
for atom_type in list(atom_map_redux.keys()):
list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
while len(list_minus) != len(list_plus):
if len(list_minus) > len(list_plus):
atoms_batch[s_idx].append(atom_type)
atoms_batch[s_idx] += " " + atom_type
atoms_batch_for_polarities[s_idx].append(atom_type)
polarities[s_idx].append(True)
else:
atoms_batch[s_idx].append(atom_type)
atoms_batch[s_idx] += " " + atom_type
atoms_batch_for_polarities[s_idx].append(atom_type)
polarities[s_idx].append(False)
list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
and atoms_batch[s_idx][i] == atom_type]
list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
and atoms_batch[s_idx][i] == atom_type]
list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
and atoms_batch_for_polarities[s_idx][i] == atom_type]
list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
and atoms_batch_for_polarities[s_idx][i] == atom_type]
return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
max_len=max_atoms_in_sentence, padding_value=0)
################################ Prepare encoding ###############################################
print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)", "s"]]))
# endregion
# region get idx for pos and neg
def get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
inverse_atom_map = {v: k for k, v in atom_map.items()}
def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
atoms_batch_for_polarities = list(
map(lambda sentence: sentence.split(" "), atoms_batch))
pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(
re.match(r"" + atom_type + "_?\w*", inverse_atom_map[int(atoms_batch_tokenized[s_idx][i])])) and
atoms_polarity_batch[s_idx][i]]) for s_idx, sentence in
enumerate(atoms_batch_tokenized)], max_len=max_atoms_in_one_type // 2, padding_value=-1)
re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i])) and
atoms_polarity_batch[s_idx][i]])
for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
max_len=max_atoms_in_one_type // 2, padding_value=-1)
for atom_type in list(atom_map_redux.keys())]
return torch.stack(pos_idx).permute(1, 0, 2)
def get_neg_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
inverse_atom_map = {v: k for k, v in atom_map.items()}
neg_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(
re.match(r"" + atom_type + "_?\w*", inverse_atom_map[int(atoms_batch_tokenized[s_idx][i])])) and
not atoms_polarity_batch[s_idx][i]]) for s_idx, sentence in
enumerate(atoms_batch_tokenized)], max_len=max_atoms_in_one_type // 2, padding_value=-1)
def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
atoms_batch_for_polarities = list(
map(lambda sentence: sentence.split(" "), atoms_batch))
neg_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(
re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i])) and not
atoms_polarity_batch[s_idx][i]])
for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
max_len=max_atoms_in_one_type // 2, padding_value=-1)
for atom_type in list(atom_map_redux.keys())]
return torch.stack(neg_idx).permute(1, 0, 2)
return torch.stack(neg_idx).permute(1, 0, 2)
print(" test for cut into pos neg on ['s np [SEP] s [SEP] np s s n n']", get_neg_idx(['s np [SEP] s [SEP] np s s n n'],
torch.as_tensor(
[[False, True, False, False,
False, False, True, True,
False, True,
False, False]]), 10))
# endregion
#!/bin/sh
#SBATCH --job-name=Deepgrail_Linker_9000
#SBATCH --partition=RTX6000Node
#SBATCH --job-name=Deepgrail_Linker
#SBATCH --partition=GPUNodes
#SBATCH --gres=gpu:1
#SBATCH --mem=32000
#SBATCH --gres-flags=enforce-binding
......
scp -r cdepourt@osirim-slurm.irit.fr:projets/deepgrail2/deepgrail_RNN_with_linker/TensorBoard/Tranning_19-05_09-49/logs /home/cdepourt/Bureau/deepgrail_RNN_with_linker/logs
scp -r cdepourt@osirim-slurm.irit.fr:projets/deepgrailGPU1/deepgrail_RNN_with_linker/TensorBoard/ /home/cdepourt/Bureau/deepgrail_RNN_with_linker/TensorBoard
rsync -av -e ssh --exclude="__pycache__" --exclude="venv" --exclude=".git" --exclude=".idea" -r /home/cdepourt/Bureau/deepgrail_RNN_with_linker cdepourt@osirim-slurm.irit.fr:projets/deepgrail2