Skip to content
Snippets Groups Projects
Commit b5a9eeab authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

update utils

parent b0e85309
No related branches found
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
......@@ -12,15 +12,15 @@ max_atoms_in_one_type=250
dim_encoder = 768
[MODEL_DECODER]
dim_decoder = 8
dim_decoder = 16
num_rnn_layers=1
dropout=0.1
teacher_forcing=0.05
[MODEL_LINKER]
nhead=1
nhead=4
dim_feedforward=246
dim_embedding_atoms=8
dim_embedding_atoms=16
dim_polarity_transfo=128
layer_norm_eps=1e-5
dropout=0.1
......
......@@ -17,7 +17,8 @@ from Linker.AtomTokenizer import AtomTokenizer
from Linker.MHA import AttentionDecoderLayer
from Linker.atom_map import atom_map
from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links
from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
get_neg_encoding_for_s_idx
from Linker.eval import mesure_accuracy, SinkhornLoss
from utils import pad_sequence
......@@ -130,23 +131,17 @@ class Linker(Module):
link_weights = []
for atom_type in list(self.atom_map.keys())[:-1]:
pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
atoms_batch_tokenized[s_idx][i] == self.atom_map[
atom_type] and
atoms_polarity_batch[s_idx][i])] + [
torch.zeros(self.dim_embedding_atoms, device=self.device)]).to(self.device)
for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2).to(self.device)
neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
atoms_batch_tokenized[s_idx][i] == self.atom_map[
atom_type] and
not atoms_polarity_batch[s_idx][i])] + [
torch.zeros(self.dim_embedding_atoms, device=self.device)]).to(self.device)
for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2).to(self.device)
pos_encoding = pad_sequence(
[get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
atoms_polarity_batch, atom_type, s_idx)
for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2).to(self.device)
neg_encoding = pad_sequence(
[get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
atoms_polarity_batch, atom_type, s_idx)
for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2).to(self.device)
pos_encoding = self.pos_transformation(pos_encoding)
neg_encoding = self.neg_transformation(neg_encoding)
......@@ -271,23 +266,17 @@ class Linker(Module):
link_weights = []
for atom_type in list(self.atom_map.keys())[:-1]:
pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
atoms_tokenized[s_idx][i] == self.atom_map[
atom_type] and
polarities[s_idx][i])] + [
torch.zeros(self.dim_embedding_atoms, device=self.device)])
for s_idx in range(len(polarities))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2).to(self.device)
neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
atoms_tokenized[s_idx][i] == self.atom_map[
atom_type] and
not polarities[s_idx][i])] + [
torch.zeros(self.dim_embedding_atoms, device=self.device)])
for s_idx in range(len(polarities))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2).to(self.device)
pos_encoding = pad_sequence(
[get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
polarities, atom_type, s_idx)
for s_idx in range(len(polarities))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2).to(self.device)
neg_encoding = pad_sequence(
[get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
polarities, atom_type, s_idx)
for s_idx in range(len(polarities))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2).to(self.device)
pos_encoding = self.pos_transformation(pos_encoding)
neg_encoding = self.neg_transformation(neg_encoding)
......
from .Linker import Linker
\ No newline at end of file
from .Linker import Linker
from .atom_map import atom_map
from .AtomEmbedding import AtomEmbedding
\ No newline at end of file
......@@ -27,7 +27,7 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
#########################################################################################
################################ Liste des atoms avc _i########################################
################################ Liste des atoms avec _i ########################################
#########################################################################################
......@@ -72,7 +72,7 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
if category.startswith("GOAL:"):
word, cat = category.split(':')
return [cat]
return category_to_atoms_axiom_links(cat, categories_to_atoms)
elif True in res:
return [category]
else:
......@@ -103,7 +103,6 @@ def get_atoms_links_batch(category_batch):
################################ Liste des atoms ########################################
#########################################################################################
def category_to_atoms(category, categories_to_atoms):
r"""
Args:
......@@ -115,8 +114,7 @@ def category_to_atoms(category, categories_to_atoms):
res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
if category.startswith("GOAL:"):
word, cat = category.split(':')
category = re.match(r'([a-zA-Z|_]+)_\d+', cat).group(1)
return [category]
return category_to_atoms(cat, categories_to_atoms)
elif True in res:
category = re.match(r'([a-zA-Z|_]+)_\d+', category).group(1)
return [category]
......@@ -158,78 +156,41 @@ def category_to_atoms_polarity(category, polarity):
"""
category_to_polarity = []
res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
# mot final
if category.startswith("GOAL:"):
category_to_polarity.append(True)
word, cat = category.split(':')
res = [bool(re.match(r'' + atom_type + "_\d+", cat)) for atom_type in atom_map.keys()]
if True in res:
category_to_polarity.append(True)
else:
category_to_polarity += category_to_atoms_polarity(cat, True)
# le mot a une category atomique
elif True in res or category.startswith("dia") or category.startswith("box"):
category_to_polarity.append(False)
category_to_polarity.append(not polarity)
# sinon c'est une formule longue
else:
# dr = /
if category.startswith("dr"):
category_cut = regex.match(regex_categories, category).groups()
category_cut = [cat for cat in category_cut if cat is not None]
left_side, right_side = category_cut[0], category_cut[1]
if polarity == True:
# for the left side : normal
res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
if True in res or left_side.startswith("dia") or left_side.startswith("box"):
category_to_polarity.append(False)
else:
category_to_polarity += category_to_atoms_polarity(left_side, True)
# for the right side : change polarity for next right formula
res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
if True in res or right_side.startswith("dia") or right_side.startswith("box"):
category_to_polarity.append(True)
else:
category_to_polarity += category_to_atoms_polarity(right_side, False)
else:
# for the left side
res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
if True in res or left_side.startswith("dia") or left_side.startswith("box"):
category_to_polarity.append(True)
else:
category_to_polarity += category_to_atoms_polarity(left_side, False)
# for the right side : change polarity for next right formula
res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
if True in res or right_side.startswith("dia") or right_side.startswith("box"):
category_to_polarity.append(False)
else:
category_to_polarity += category_to_atoms_polarity(right_side, True)
# for the left side
category_to_polarity += category_to_atoms_polarity(left_side, polarity)
# for the right side : change polarity for next right formula
category_to_polarity += category_to_atoms_polarity(right_side, not polarity)
# dl = \
elif category.startswith("dl"):
category_cut = regex.match(regex_categories, category).groups()
category_cut = [cat for cat in category_cut if cat is not None]
left_side, right_side = category_cut[0], category_cut[1]
if polarity == True:
# for the left side : change polarity
res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
if True in res or left_side.startswith("dia") or left_side.startswith("box"):
category_to_polarity.append(True)
else:
category_to_polarity += category_to_atoms_polarity(left_side, False)
# for the right side : normal
res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
if True in res or right_side.startswith("dia") or right_side.startswith("box"):
category_to_polarity.append(False)
else:
category_to_polarity += category_to_atoms_polarity(right_side, True)
else:
# for the left side
res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
if True in res or left_side.startswith("dia") or left_side.startswith("box"):
category_to_polarity.append(False)
else:
category_to_polarity += category_to_atoms_polarity(left_side, True)
# for the right side
res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
if True in res or right_side.startswith("dia") or right_side.startswith("box"):
category_to_polarity.append(True)
else:
category_to_polarity += category_to_atoms_polarity(right_side, False)
# for the left side
category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
# for the right side
category_to_polarity += category_to_atoms_polarity(right_side, polarity)
return category_to_polarity
......@@ -251,3 +212,32 @@ def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
list_batch.append(torch.as_tensor(list_atoms))
return pad_sequence([list_batch[i] for i in range(len(list_batch))],
max_len=max_atoms_in_sentence, padding_value=0)
#########################################################################################
################################ Prepare encoding ###############################################
#########################################################################################
def get_pos_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
atom_type, s_idx):
pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
atoms_batch_tokenized[s_idx][i] == atom_map[atom_type] and
atoms_polarity_batch[s_idx][i])]
if len(pos_encoding) == 0:
return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
else:
return torch.stack(pos_encoding)
def get_neg_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
atom_type, s_idx):
neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
atoms_batch_tokenized[s_idx][i] == atom_map[atom_type] and
not atoms_polarity_batch[s_idx][i])]
if len(neg_encoding) == 0:
return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
else:
return torch.stack(neg_encoding)
......@@ -6,7 +6,7 @@ from utils import read_csv_pgbar
torch.cuda.empty_cache()
batch_size = int(Configuration.modelTrainingConfig['batch_size'])
nb_sentences = batch_size * 10
nb_sentences = batch_size * 200
epochs = int(Configuration.modelTrainingConfig['epoch'])
file_path_axiom_links = 'Datasets/aa1_links_dataset_links.csv'
......@@ -15,8 +15,6 @@ df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
sentences_batch = df_axiom_links["Sentences"].tolist()
supertagger = SuperTagger()
supertagger.load_weights("models/model_supertagger.pt")
sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
print("Linker")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment