Commit a539008d authored by Caroline DE POURTALES

change padding handling

parent 06615a06
3 merge requests: !6 Linker with transformer, !5 Linker with transformer, !3 Working on padding
@@ -5,15 +5,16 @@ transformers = 4.16.2
symbols_vocab_size=26
atom_vocab_size=18
max_len_sentence=290
max_atoms_in_sentence=874
max_atoms_in_sentence=875
max_atoms_in_one_type=324
[MODEL_ENCODER]
dim_encoder = 768
[MODEL_LINKER]
nhead=4
nhead=8
dim_emb_atom = 256
dim_feedforward_transformer = 768
num_layers=2
dim_cat_inter=512
dim_cat_out=256
......
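
The config change above raises nhead from 4 to 8 and adds dim_feedforward_transformer for the atom transformer. As a quick sanity check (a minimal sketch, not the project's Configuration class, and assuming the values live in an INI file at a hypothetical path config.ini), the new values can be read with configparser; note that dim_emb_atom=256 stays divisible by nhead=8, which TransformerEncoderLayer requires:

import configparser

config = configparser.ConfigParser()
config.read("config.ini")  # hypothetical path to the file shown above

nhead = config.getint("MODEL_LINKER", "nhead")                          # 8 after this commit
dim_emb_atom = config.getint("MODEL_LINKER", "dim_emb_atom")            # 256
dim_ff = config.getint("MODEL_LINKER", "dim_feedforward_transformer")   # 768, new key

# d_model must be divisible by the number of attention heads
assert dim_emb_atom % nhead == 0
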
@@ -69,6 +69,7 @@ class Linker(Module):
# Transformer
self.nhead = int(Configuration.modelLinkerConfig['nhead'])
self.dim_emb_atom = int(Configuration.modelLinkerConfig['dim_emb_atom'])
self.dim_feedforward_transformer = int(Configuration.modelLinkerConfig['dim_feedforward_transformer'])
self.num_layers = int(Configuration.modelLinkerConfig['num_layers'])
# torch cat
self.dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_out'])
@@ -78,7 +79,6 @@ class Linker(Module):
# sinkhorn
self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
# settings
self.batch_size = int(Configuration.modelTrainingConfig['batch_size'])
self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
@@ -95,11 +95,13 @@ class Linker(Module):
# Atoms embedding
self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
self.atom_map_redux = atom_map_redux
self.padding_id = atom_map["[PAD]"]
self.sub_atoms_type_list = list(atom_map_redux.keys())
self.atom_encoder = Embedding(self.max_atoms_in_sentence, self.dim_emb_atom, padding_idx=atom_map["[PAD]"])
self.atom_encoder = Embedding(atom_vocab_size, self.dim_emb_atom, padding_idx=self.padding_id)
self.atom_encoder.weight.data.uniform_(-0.1, 0.1)
self.position_encoder = PositionalEncoding(self.dim_emb_atom, 0.1, max_len=self.max_atoms_in_sentence)
encoder_layer = TransformerEncoderLayer(d_model=self.dim_emb_atom, nhead=self.nhead)
encoder_layer = TransformerEncoderLayer(d_model=self.dim_emb_atom, nhead=self.nhead,
dim_feedforward=self.dim_feedforward_transformer, dropout=0.1)
self.transformer = TransformerEncoder(encoder_layer, num_layers=self.num_layers)
# Concatenation with word embedding
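
The atom embedding is now sized by atom_vocab_size instead of max_atoms_in_sentence and gets padding_idx=self.padding_id. A minimal sketch of what padding_idx buys (toy sizes; pad_id=0 is an assumed value of atom_map["[PAD]"]): with a freshly constructed embedding the [PAD] row is zero, and its gradient is always zeroed, so padded positions cannot drift during training:

import torch
from torch.nn import Embedding

atom_vocab_size, dim_emb_atom, pad_id = 18, 8, 0   # toy dims; pad_id assumed
enc = Embedding(atom_vocab_size, dim_emb_atom, padding_idx=pad_id)

tokens = torch.tensor([[3, 5, pad_id, pad_id]])    # one sentence padded to length 4
out = enc(tokens)
out.sum().backward()

print(out[0, 2])                 # zero vector for the padded position
print(enc.weight.grad[pad_id])   # all zeros: the [PAD] row is never updated
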
@@ -146,8 +148,8 @@ class Linker(Module):
num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
neg_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
neg_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
df_axiom_links["Y"])
@@ -170,12 +172,11 @@ class Linker(Module):
print("End preprocess Data")
return training_dataloader, validation_dataloader
def forward(self, batch_num_atoms_per_word, batch_atoms, src_mask, batch_pos_idx, batch_neg_idx, sents_embedding):
def forward(self, batch_num_atoms_per_word, batch_atoms, batch_pos_idx, batch_neg_idx, sents_embedding):
r"""
Args:
batch_num_atoms_per_word : (batch_size, len_sentence) flattened categories
batch_atoms : atoms tok
src_mask : atoms mask
batch_pos_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
batch_neg_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
@@ -187,10 +188,14 @@ class Linker(Module):
[torch.repeat_interleave(input=sents_embedding[i], repeats=batch_num_atoms_per_word[i], dim=0)
for i in range(len(sents_embedding))], max_len=self.max_atoms_in_sentence, padding_value=0)
# atoms embedding
src_key_padding_mask = torch.eq(batch_atoms, self.padding_id)
src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
atoms_embedding = self.atom_encoder(batch_atoms) * math.sqrt(self.dim_emb_atom)
atoms_embedding = self.position_encoder(atoms_embedding)
atoms_embedding = atoms_embedding.permute(1, 0, 2)
atoms_embedding = self.transformer(atoms_embedding, src_mask)
atoms_embedding = self.transformer(atoms_embedding, src_mask,
src_key_padding_mask=src_key_padding_mask)
atoms_embedding = atoms_embedding.permute(1, 0, 2)
# cat
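
forward no longer receives src_mask as an argument; both masks are now built inside it: a boolean src_key_padding_mask from the padding id and a square subsequent mask. A minimal sketch of how the two masks combine in a stock TransformerEncoder (toy sizes; generate_square_subsequent_mask is assumed to be the usual causal float mask from the PyTorch tutorial):

import math
import torch
from torch.nn import Embedding, TransformerEncoder, TransformerEncoderLayer

def generate_square_subsequent_mask(sz):
    # assumed helper: -inf above the diagonal, 0 elsewhere (causal mask)
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)

pad_id, vocab, d_model, seq_len = 0, 18, 256, 6
batch_atoms = torch.tensor([[3, 5, 7, pad_id, pad_id, pad_id],
                            [2, 4, 6, 8, 9, pad_id]])

src_key_padding_mask = torch.eq(batch_atoms, pad_id)   # (batch, seq), True where padded
src_mask = generate_square_subsequent_mask(seq_len)    # (seq, seq)

layer = TransformerEncoderLayer(d_model=d_model, nhead=8, dim_feedforward=768, dropout=0.1)
encoder = TransformerEncoder(layer, num_layers=2)

emb = Embedding(vocab, d_model, padding_idx=pad_id)
x = emb(batch_atoms) * math.sqrt(d_model)
x = x.permute(1, 0, 2)                                 # (seq, batch, d_model), batch_first=False
out = encoder(x, src_mask, src_key_padding_mask=src_key_padding_mask)
print(out.permute(1, 0, 2).shape)                      # torch.Size([2, 6, 256])

The key padding mask alone already blocks attention to [PAD] positions; keeping the square subsequent mask on top of it is a modelling choice of the linker.
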
@@ -280,7 +285,6 @@ class Linker(Module):
# For each batch of training data...
with tqdm(training_dataloader, unit="batch") as tepoch:
src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
for batch in tepoch:
# Unpack this training batch from our dataloader
batch_num_atoms = batch[0].to(self.device)
@@ -297,10 +301,10 @@ class Linker(Module):
output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
# Run the Linker on the atoms
logits_predictions = self(batch_num_atoms, batch_atoms_tok, src_mask, batch_pos_idx, batch_neg_idx,
logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx,
output['word_embeding'])
linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links, self.max_atoms_in_one_type)
# Perform a backward pass to calculate the gradients.
epoch_loss += float(linker_loss)
linker_loss.backward()
@@ -334,19 +338,17 @@ class Linker(Module):
output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
logits_predictions = self(batch_num_atoms, batch_atoms_tok, src_mask, batch_pos_idx, batch_neg_idx, output[
logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output[
'word_embeding']) # atom_vocab, batch_size, max atoms in one type, max atoms in one type
axiom_links_pred = torch.argmax(logits_predictions, dim=3) # atom_vocab, batch_size, max atoms in one type
print('\n')
print("Tokens de la phrase : ", batch_sentences_tokens[1])
print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
print("Les prédictions : ", axiom_links_pred[2][1][:100])
print('\n')
accuracy = mesure_accuracy(batch_true_links, axiom_links_pred, self.max_atoms_in_one_type)
loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
loss = self.cross_entropy_loss(logits_predictions, batch_true_links, self.max_atoms_in_one_type)
return loss, accuracy
......
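
In eval_batch the comment gives logits_predictions the shape (atom_vocab, batch_size, max atoms in one type, max atoms in one type), so the argmax over dim=3 reads out, for each positive atom, the index of the negative atom it is linked to. A toy sketch with random scores and made-up sizes:

import torch

atom_vocab, batch_size, n_pos, n_neg = 2, 1, 4, 4      # toy sizes
logits = torch.log_softmax(torch.randn(atom_vocab, batch_size, n_pos, n_neg), dim=3)

axiom_links_pred = torch.argmax(logits, dim=3)         # (atom_vocab, batch_size, n_pos)
print(axiom_links_pred.shape)                          # torch.Size([2, 1, 4])
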
import torch
from torch.nn import Module
from torch.nn.functional import nll_loss
from Linker.atom_map import atom_map, atom_map_redux
class SinkhornLoss(Module):
def __init__(self):
super(SinkhornLoss, self).__init__()
def forward(self, predictions, truths):
return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
def forward(self, predictions, truths, max_atoms_in_one_type):
return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
for link, perm in zip(predictions, truths.permute(1, 0, 2)))
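
With the padding value switched to -1, ignore_index=-1 makes nll_loss drop padded link targets from both the sum and the mean. A small standalone example of the same call pattern (toy tensors, not the project's data):

import torch
from torch.nn.functional import nll_loss

# one atom type: 3 positive atoms, 4 candidate negative atoms
log_probs = torch.log_softmax(torch.randn(1, 3, 4), dim=-1)   # (batch, pos, neg)
truths = torch.tensor([[2, 0, -1]])                           # -1 marks a padded slot

loss = nll_loss(log_probs.flatten(0, 1), truths.flatten(),
                reduction='mean', ignore_index=-1)
print(loss)   # averaged over the two real links only
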
@@ -17,7 +18,7 @@ def mesure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type):
batch_true_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
"""
padding = max_atoms_in_one_type // 2 - 1
padding = -1
batch_true_links = batch_true_links.permute(1, 0, 2)
correct_links = torch.ones(axiom_links_pred.size())
correct_links[axiom_links_pred != batch_true_links] = 0
......
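
Since the padding marker in the link targets is now -1 rather than max_atoms_in_one_type // 2 - 1, accuracy should only be counted over real links. A minimal sketch of that masking (not the project's mesure_accuracy, just the idea):

import torch

padding = -1
pred = torch.tensor([[2, 0, 1, 3]])
true = torch.tensor([[2, 0, 3, padding]])    # the last slot is padding

mask = true != padding                        # keep only real links
accuracy = (pred[mask] == true[mask]).float().mean()
print(accuracy)                               # tensor(0.6667): 2 of 3 real links match
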
@@ -45,18 +45,18 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
for atom_type in list(atom_map_redux.keys()):
# filter atoms_batch to this type only, then use those indices to filter atoms_polarity
l_polarity_plus = [[x for i, x in enumerate(atoms_batch[s_idx]) if atoms_polarity[s_idx, i]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx in
range(len(atoms_batch))]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx
in range(len(atoms_batch))]
l_polarity_minus = [[x for i, x in enumerate(atoms_batch[s_idx]) if not atoms_polarity[s_idx, i]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx in
range(len(atoms_batch))]
and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx
in range(len(atoms_batch))]
linking_plus_to_minus = pad_sequence(
[torch.as_tensor(
[l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else max_atoms_in_one_type // 2 - 1
[l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else -1
for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2,
padding_value=max_atoms_in_one_type // 2 - 1)
padding_value=-1)
linking_plus_to_minus_all_types.append(linking_plus_to_minus)
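
get_axiom_links now pads unmatched and missing links with -1 instead of max_atoms_in_one_type // 2 - 1, which matches the ignore_index=-1 used in the loss. The repository appears to use its own pad_sequence helper with a max_len argument; with stock PyTorch the same padding can be sketched like this (hypothetical toy data):

import torch
from torch.nn.functional import pad
from torch.nn.utils.rnn import pad_sequence

# link targets for three sentences of one atom type, of varying length
links = [torch.tensor([1, 0, 2]), torch.tensor([0]), torch.tensor([], dtype=torch.long)]

padded = pad_sequence(links, batch_first=True, padding_value=-1)    # (3, 3)
fixed_len = 5                                                       # stand-in for max_atoms_in_one_type // 2
padded = pad(padded, (0, fixed_len - padded.size(1)), value=-1)     # pad the last dim to a fixed length
print(padded)
# tensor([[ 1,  0,  2, -1, -1],
#         [ 0, -1, -1, -1, -1],
#         [-1, -1, -1, -1, -1]])
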
@@ -108,8 +108,12 @@ def get_atoms_links_batch(category_batch):
print("test to create links ",
get_axiom_links(20, torch.stack([torch.as_tensor([False, True, False, False, False, True, False, True, False, False, True, False, False, False, True, False, False, True, False, True, False, False, True, False, False, False, True])]),
[['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]))
get_axiom_links(20, torch.stack([torch.as_tensor(
[False, True, False, False, False, True, False, True, False, False, True, False, False, False, True, False,
False, True, False, True, False, False, True, False, False, False, True])]),
[['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)',
'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]))
# endregion
@@ -305,8 +309,10 @@ def find_pos_neg_idexes(atoms_batch):
return list_batch
print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']",
find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']]))
print(
" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']",
find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']]))
# endregion
@@ -349,11 +355,12 @@ print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)",
# region get idx for pos and neg
def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
atoms_batch_for_polarities = list(
map(lambda sentence: sentence.split(" "), atoms_batch))
pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(
re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i])) and
pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
atoms_batch_for_polarities[s_idx][i])) and
atoms_polarity_batch[s_idx][i]])
for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
max_len=max_atoms_in_one_type // 2, padding_value=-1)
@@ -362,11 +369,12 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
return torch.stack(pos_idx).permute(1, 0, 2)
def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
atoms_batch_for_polarities = list(
map(lambda sentence: sentence.split(" "), atoms_batch))
pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(
re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i])) and not
pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
atoms_batch_for_polarities[s_idx][i])) and not
atoms_polarity_batch[s_idx][i]])
for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
max_len=max_atoms_in_one_type // 2, padding_value=-1)
@@ -380,6 +388,6 @@ print(" test for cut into pos neg on ['s np [SEP] s [SEP] np s s n n']", get_neg
[[False, True, False, False,
False, False, True, True,
False, True,
False, False]]), 10))
False, False]]), 10, 50))
# endregion
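
get_pos_idx and get_neg_idx now also take max_atoms_in_sentence, but the core of both remains the regex that picks out atoms of one type before splitting them by polarity. A small standalone illustration of that pattern, with a made-up atom list:

import re

atom_type = "np"
pattern = rf"{atom_type}(_\w+)?\Z"       # equivalent to the "(_{1}\w+)?\Z" pattern used above

print(bool(re.match(pattern, "np_3")))   # True
print(bool(re.match(pattern, "np")))     # True
print(bool(re.match(pattern, "n_2")))    # False: "n" is a different atom type

atoms = "np_1 n_2 np_3 s_4".split(" ")
pos_of_np = [i for i, a in enumerate(atoms) if re.match(pattern, a)]
print(pos_of_np)                         # [0, 2]
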
#!/bin/sh
#SBATCH --job-name=Deepgrail_Linker
#SBATCH --partition=GPUNodes
#SBATCH --partition=RTX6000Node
#SBATCH --gres=gpu:1
#SBATCH --mem=32000
#SBATCH --gres-flags=enforce-binding
......
5 files deleted