Commit 2087b8ae authored by Caroline DE POURTALES

adding goal linking

parent b04c7ce2
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
@@ -12,19 +12,19 @@ max_atoms_in_one_type=324
 dim_encoder = 768
 
 [MODEL_LINKER]
-nhead=16
+nhead=8
 dim_emb_atom = 256
-dim_feedforward_transformer = 512
+dim_feedforward_transformer = 768
 num_layers=3
 dim_cat_inter=512
 dim_cat_out=256
 dim_intermediate_FFN=128
-dim_pre_sinkhorn_transfo=64
+dim_pre_sinkhorn_transfo=32
 dropout=0.1
 sinkhorn_iters=5
 
 [MODEL_TRAINING]
-batch_size=16
+batch_size=32
 epoch=30
 seed_val=42
 learning_rate=2e-3
\ No newline at end of file
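For orientation: the Configuration.modelLinkerConfig[...] lookups in the Python hunks below read these INI values as strings and cast them by hand. A minimal sketch of that pattern, assuming the project's Configuration module wraps Python's standard configparser (the file path here is hypothetical):

import configparser

config = configparser.ConfigParser()
config.read("config.ini")  # hypothetical path to the file edited above

linker_cfg = config["MODEL_LINKER"]
nhead = int(linker_cfg["nhead"])                    # 8 after this commit
dropout = float(linker_cfg["dropout"])              # 0.1
sinkhorn_iters = int(linker_cfg["sinkhorn_iters"])  # 5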
@@ -9,7 +9,7 @@ import time
 import torch
 import torch.nn.functional as F
 from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout, TransformerEncoderLayer, TransformerEncoder, \
-    Embedding
+    Embedding, GELU
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import TensorDataset, random_split
@@ -72,7 +72,8 @@ class Linker(Module):
         self.dim_feedforward_transformer = int(Configuration.modelLinkerConfig['dim_feedforward_transformer'])
         self.num_layers = int(Configuration.modelLinkerConfig['num_layers'])
         # torch cat
-        self.dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_out'])
+        dropout = float(Configuration.modelLinkerConfig['dropout'])
+        self.dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_inter'])
         self.dim_cat_out = int(Configuration.modelLinkerConfig['dim_cat_out'])
         dim_intermediate_FFN = int(Configuration.modelLinkerConfig['dim_intermediate_FFN'])
         dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo'])
@@ -107,7 +108,9 @@ class Linker(Module):
 
         # Concatenation with word embedding
         dim_cat = dim_encoder + self.dim_emb_atom
         self.linker_encoder = Sequential(
-            FFN(dim_cat, self.dim_cat_inter, 0.1, d_out=self.dim_cat_out),
+            Linear(dim_cat, self.dim_cat_out),
+            GELU(),
+            Dropout(dropout),
             LayerNorm(self.dim_cat_out, eps=1e-8)
         )
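Read on its own, the new projection head replaces the project's FFN block with a single Linear-GELU-Dropout-LayerNorm stack. A self-contained sketch using the post-commit dimensions from the config diff (dim_cat = 768 + 256, dim_cat_out = 256), shown only to make the shapes concrete:

import torch
from torch.nn import Sequential, Linear, GELU, Dropout, LayerNorm

dim_cat, dim_cat_out, dropout = 768 + 256, 256, 0.1  # values from the config diff above

linker_encoder = Sequential(
    Linear(dim_cat, dim_cat_out),   # project concatenated word+atom embedding down
    GELU(),
    Dropout(dropout),
    LayerNorm(dim_cat_out, eps=1e-8),
)

x = torch.randn(2, 10, dim_cat)    # (batch, atoms, dim_cat)
print(linker_encoder(x).shape)     # torch.Size([2, 10, 256])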
@@ -142,11 +145,8 @@ class Linker(Module):
         sentences_batch = df_axiom_links["X"].str.strip().tolist()
         sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
-        atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
-        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(
-            list(map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch)))
-        num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
+        atoms_batch, atoms_polarity_batch, num_atoms_per_word = get_GOAL(self.max_len_sentence, self.max_atoms_in_sentence, df_axiom_links)
+        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
         pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
         neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
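The data-preparation hunk above now passes atoms_batch, a list of lists of atom strings, straight to convert_batchs_to_ids instead of splitting strings first. The project's atom tokenizer is not shown in this diff; the following is a hypothetical stand-in, just to make the new interface concrete:

class AtomTokenizer:
    """Hypothetical stand-in for the project's atom tokenizer (not part of this commit)."""

    def __init__(self, vocab):
        self.to_id = {atom: i + 1 for i, atom in enumerate(vocab)}  # 0 reserved for padding

    def convert_batchs_to_ids(self, batch, max_len=30):
        # each sentence is already a list of atom strings, so no split(" ") is needed
        rows = [[self.to_id.get(atom, 0) for atom in sentence] for sentence in batch]
        return [row + [0] * (max_len - len(row)) for row in rows]

tokenizer = AtomTokenizer(["np", "n", "s", "[SEP]"])
print(tokenizer.convert_batchs_to_ids([["np", "n", "[SEP]"]], max_len=6))  # [[1, 2, 4, 0, 0, 0]]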
 import re
+import pandas as pd
 import regex
 import torch
 from torch.nn import Sequential, Linear, Dropout, GELU
@@ -40,7 +42,6 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
     batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
     atoms_batch = get_atoms_links_batch(batch_axiom_links)
-    atoms_batch = list(map(lambda sentence: sentence.split(" "), atoms_batch))
     linking_plus_to_minus_all_types = []
     for atom_type in list(atom_map_redux.keys()):
         # filter atoms_batch down to this type, then use the indices to filter atom polarity
@@ -76,12 +77,12 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
         word, cat = category.split(':')
         return category_to_atoms_axiom_links(cat, categories_to_atoms)
     elif True in res:
-        return " " + category
+        return [category]
     else:
         category_cut = regex.match(regex_categories_axiom_links, category).groups()
         category_cut = [cat for cat in category_cut if cat is not None]
         for cat in category_cut:
-            categories_to_atoms += category_to_atoms_axiom_links(cat, "")
+            categories_to_atoms += category_to_atoms_axiom_links(cat, [])
         return categories_to_atoms
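To see what the list-returning recursion computes, here is a hedged, self-contained re-creation: it uses a hand-rolled top-level splitter in place of regex_categories_axiom_links (which is defined elsewhere in the file), and the atom inventory is illustrative:

import re

ATOMS = {"np", "n", "s", "pp", "txt"}  # assumed atom inventory, for illustration only

def split_args(body):
    # split on top-level commas only; arguments may contain nested dr(...)/dl(...)
    args, depth, start = [], 0, 0
    for i, ch in enumerate(body):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif ch == ',' and depth == 0:
            args.append(body[start:i])
            start = i + 1
    args.append(body[start:])
    return args

def atoms_of(category):
    if category.split('_')[0] in ATOMS:  # bare or indexed atom, e.g. 'n' or 'np_1'
        return [category]
    mode, body = re.match(r"(?:dr|dl)\((\d+),(.*)\)\Z", category).groups()
    atoms = []
    for arg in split_args(body):
        atoms += atoms_of(arg)           # recurse, concatenating atom lists
    return atoms

print(atoms_of('dr(0,dl(0,np_1,np_3),np_4)'))  # ['np_1', 'np_3', 'np_4']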
@@ -94,23 +95,22 @@ def get_atoms_links_batch(category_batch):
     """
     batch = []
     for sentence in category_batch:
-        categories_to_atoms = ""
+        categories_to_atoms = []
         for category in sentence:
             if category != "let" and not category.startswith("GOAL:"):
-                categories_to_atoms += category_to_atoms_axiom_links(category, "")
-                categories_to_atoms += " [SEP]"
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms += category_to_atoms_axiom_links(category, [])
+                categories_to_atoms.append("[SEP]")
             elif category.startswith("GOAL:"):
-                categories_to_atoms += category_to_atoms_axiom_links(category, "")
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms = category_to_atoms_axiom_links(category, []) + categories_to_atoms
         batch.append(categories_to_atoms)
     return batch
print("test to create links ", print("test to create links ",
get_axiom_links(20, torch.stack([torch.as_tensor( get_axiom_links(20, torch.stack([torch.as_tensor(
[False, True, False, False, False, True, False, True, False, False, True, False, False, False, True, False, [True, False, True, False, False, False, True, False, True, False,
False, True, False, True, False, False, True, False, False, False, True])]), False, True, False, False, False, True, False, False, True, False,
True, False, False, True, False, False, False, False, False, False])]),
[['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)', [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)',
'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']])) 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]))
@@ -132,12 +132,12 @@ def category_to_atoms(category, categories_to_atoms):
         word, cat = category.split(':')
         return category_to_atoms(cat, categories_to_atoms)
     elif True in res:
-        return " " + category
+        return [category]
     else:
         category_cut = regex.match(regex_categories, category).groups()
         category_cut = [cat for cat in category_cut if cat is not None]
         for cat in category_cut:
-            categories_to_atoms += category_to_atoms(cat, "")
+            categories_to_atoms += category_to_atoms(cat, [])
         return categories_to_atoms
@@ -150,17 +150,17 @@ def get_atoms_batch(category_batch):
     """
     batch = []
     for sentence in category_batch:
-        categories_to_atoms = ""
+        categories_to_atoms = []
         for category in sentence:
             if category != "let":
-                categories_to_atoms += category_to_atoms(category, "")
-                categories_to_atoms += " [SEP]"
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms += category_to_atoms(category, [])
+                categories_to_atoms.append("[SEP]")
         batch.append(categories_to_atoms)
     return batch
 
-print(" test for get atoms in categories on ['dr(0,s,np)', 'let']", get_atoms_batch([["dr(0,s,np)", "let"]]))
+print(" test for get atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']",
+      get_atoms_batch([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']]))
 
 # endregion
@@ -201,7 +201,7 @@ def get_num_atoms_batch(category_batch, max_len_sentence):
     """
     batch = []
     for sentence in category_batch:
-        num_atoms_sentence = []
+        num_atoms_sentence = [0]
        for category in sentence:
             num_atoms_in_word = category_to_num_atoms(category, 0)
             # add 1 because for word we have SEP at the end
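The leading 0 reserves a per-sentence slot for the goal's atom count, which the new get_GOAL below fills in. A hedged illustration of the resulting bookkeeping (the word counts are made up):

# Hypothetical numbers: three words carrying 3, 1 and 4 atoms (+1 each for [SEP])
num_atoms_sentence = [0, 4, 2, 5]   # slot 0 reserved for the goal
# get_GOAL later prepends the goal atoms and adds their count into slot 0:
num_atoms_sentence[0] += 1          # e.g. a single goal atom such as 'np'
print(num_atoms_sentence)           # [1, 4, 2, 5]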
@@ -309,8 +309,7 @@ def find_pos_neg_idexes(atoms_batch):
     return list_batch
 
-print(
-    " test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']",
+print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np'] \n",
     find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
                           'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']]))
@@ -319,36 +318,32 @@ print(
 
 # region get atoms and polarities with GOAL
-def get_GOAL(max_atoms_in_sentence, categories_batch):
+def get_GOAL(max_len_sentence, max_atoms_in_sentence, df_axiom_links):
+    categories_batch = df_axiom_links["Z"]
+    categories_with_goal = df_axiom_links["Y"]
     polarities = find_pos_neg_idexes(categories_batch)
     atoms_batch = get_atoms_batch(categories_batch)
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch))
+    num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence)
     for s_idx in range(len(atoms_batch)):
-        for atom_type in list(atom_map_redux.keys()):
-            list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
-                         and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
-            list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
-                          and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
-            while len(list_minus) != len(list_plus):
-                if len(list_minus) > len(list_plus):
-                    atoms_batch[s_idx] += " " + atom_type
-                    atoms_batch_for_polarities[s_idx].append(atom_type)
-                    polarities[s_idx].append(True)
-                else:
-                    atoms_batch[s_idx] += " " + atom_type
-                    atoms_batch_for_polarities[s_idx].append(atom_type)
-                    polarities[s_idx].append(False)
-                list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
-                             and atoms_batch_for_polarities[s_idx][i] == atom_type]
-                list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
-                              and atoms_batch_for_polarities[s_idx][i] == atom_type]
+        goal = categories_with_goal[s_idx][-1]
+        polarities_goal = category_to_atoms_polarity(goal, True)
+        goal = re.search(r"(\w+)_\d+", goal).groups()[0]
+        atoms = category_to_atoms(goal, [])
+
+        atoms_batch[s_idx] = atoms + atoms_batch[s_idx]  # + ["[SEP]"]
+        polarities[s_idx] = polarities_goal + polarities[s_idx]  # + False
+        num_atoms_batch[s_idx][0] += len(atoms)  # +1
 
     return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-                                     max_len=max_atoms_in_sentence, padding_value=0)
+                                     max_len=max_atoms_in_sentence, padding_value=0), num_atoms_batch
 
-print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)", "s"]]))
+df_axiom_links = pd.DataFrame({"Z": [['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
+                                      'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']],
+                               "Y": [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6',
+                                      'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9',
+                                      'GOAL:np_7']]})
+print(" test for get GOAL ", get_GOAL(10, 30, df_axiom_links))
 
 # endregion
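Note that the pad_sequence called here takes a max_len keyword, unlike torch.nn.utils.rnn.pad_sequence, so it is presumably a project-local helper defined outside this diff. A minimal sketch of what such a helper would do, under that assumption:

import torch

def pad_sequence(tensors, max_len, padding_value=0):
    """Hypothetical re-creation: right-pad 1-D tensors to max_len and stack them."""
    out = torch.full((len(tensors), max_len), padding_value, dtype=tensors[0].dtype)
    for i, t in enumerate(tensors):
        n = min(len(t), max_len)   # truncate if a sentence exceeds max_len
        out[i, :n] = t[:n]
    return out

print(pad_sequence([torch.as_tensor([True, False]), torch.as_tensor([True])], max_len=4))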
@@ -356,13 +351,11 @@ print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)",
 
 # region get idx for pos and neg
 def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: sentence.split(" "), atoms_batch))
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
-                                                            atoms_batch_for_polarities[s_idx][i])) and
-                                              atoms_polarity_batch[s_idx][i]])
-                             for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
+                                              bool(
+                                                  re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
+                                              atoms_polarity_batch[s_idx][i]])
+                             for s_idx, sentence in enumerate(atoms_batch)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
                for atom_type in list(atom_map_redux.keys())]
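The per-type filter above hinges on the pattern r"" + atom_type + "(_{1}\w+)?\Z", which accepts a bare atom or an indexed one but rejects longer atom names that merely share a prefix. A quick standalone check:

import re

atom_type = "np"
pattern = r"" + atom_type + r"(_{1}\w+)?\Z"
print([a for a in ["np", "np_7", "n_2", "s"] if re.match(pattern, a)])  # ['np', 'np_7']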
@@ -370,24 +363,23 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_at
 
 def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: sentence.split(" "), atoms_batch))
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
-                                                            atoms_batch_for_polarities[s_idx][i])) and not
-                                              atoms_polarity_batch[s_idx][i]])
-                             for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
+                                              bool(
+                                                  re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
+                                              not atoms_polarity_batch[s_idx][i]])
+                             for s_idx, sentence in enumerate(atoms_batch)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
                for atom_type in list(atom_map_redux.keys())]
 
     return torch.stack(pos_idx).permute(1, 0, 2)
print(" test for cut into pos neg on ['s np [SEP] s [SEP] np s s n n']", get_neg_idx(['s np [SEP] s [SEP] np s s n n'], print(" test for cut into pos neg on ['dr(0,s,np)', 's']",
torch.as_tensor( get_neg_idx([['s', 's', 'np', 's', 'np', '[SEP]', 's', '[SEP]']],
[[False, True, False, False, torch.as_tensor(
False, False, True, True, [[True, True, False, False,
False, True, True, False, False, False,
False, False]]), 10, 50)) False, False,
False, False]]), 10, 50))
# endregion # endregion
 numpy==1.22.2
-transformers==4.16.2
-torch==1.9.0
 huggingface-hub==0.4.0
 pandas==1.4.1
-sentencepiece
-git+https://gitlab.irit.fr/pnria/global-helper/deepgrail-rnn/
\ No newline at end of file
+Markdown==3.3.6
+packaging==21.3
+scikit-learn==1.0.2
+scipy==1.8.0
+sentencepiece==0.1.96
+tensorflow==2.9.1
+tensorboard==2.8.0
+torch==1.11.0
+tqdm==4.64.0
+transformers==4.19.0