-
Caroline DE POURTALES authored
utils_linker.py 13.09 KiB
import re
import regex
import torch
from torch.nn import Sequential, Linear, Dropout, GELU
from torch.nn import Module
from Linker.atom_map import atom_map, atom_map_redux
from utils import pad_sequence
class FFN(Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear, both projections bias-free.

    Args:
        d_model: input width (and default output width).
        d_ff: hidden width.
        dropout: dropout probability applied after the activation.
        d_out: output width; ``d_model`` when left as ``None``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1, d_out=None):
        super().__init__()
        output_dim = d_model if d_out is None else d_out
        # The attribute name "ffn" and the layer order fix the state_dict keys
        # ("ffn.0.weight", "ffn.3.weight"); keep them stable for saved checkpoints.
        self.ffn = Sequential(
            Linear(d_model, d_ff, bias=False),
            GELU(),
            Dropout(dropout),
            Linear(d_ff, output_dim, bias=False),
        )

    def forward(self, x):
        """Run ``x`` (last dimension ``d_model``) through the feed-forward stack."""
        return self.ffn(x)
#########################################################################################
################################ Regex ########################################
#########################################################################################
# Compound category: a word, "(", digits, ",", then up to two argument slots, optionally
# separated by ",", then ")". Each slot is either a plain word or, through `(?R)`, a nested
# compound category. `(?R)` (whole-pattern recursion) is not supported by the stdlib `re`
# module, so these patterns are used with the third-party `regex` module.
# Capture groups: (1) nested / (2) word for the first slot, (3) nested / (4) word for the second.
_CATEGORY_REGEX = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
# Both public names hold the same pattern; both are kept for existing callers.
regex_categories_axiom_links = _CATEGORY_REGEX
regex_categories = _CATEGORY_REGEX
#########################################################################################
################################ List of atoms with _i ##################################
#########################################################################################
def get_axiom_links(max_atoms_in_one_type, sub_atoms_type_list, atoms_polarity, batch_axiom_links):
    r"""
    For each atom type, map every positive atom to the index of its linked negative atom.

    Atoms are linked when they carry the same label (type plus ``_i`` suffix).

    Args:
        max_atoms_in_one_type : configuration; each sentence's link vector is padded to
            ``max_atoms_in_one_type // 2``
        sub_atoms_type_list : list of atom types to match
        atoms_polarity : (batch_size, max_atoms_in_sentence) polarity of each atom, indexed as
            ``atoms_polarity[s_idx, i]`` (truthy = positive)
        batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
    Returns:
        batch_true_links : one ``pad_sequence`` result per atom type (padding value -1), stacked along
            dim 0, i.e. presumably (len(sub_atoms_type_list), batch_size, max_atoms_in_one_type // 2)
            if ``pad_sequence`` is batch-first — NOTE(review): confirm against ``utils.pad_sequence``.
            Entry j of a sentence is the position, among that sentence's negative atoms of the type,
            of the first negative atom whose label equals the j-th positive atom's label, or -1.
    """
    atoms_batch = get_atoms_links_batch(batch_axiom_links)
    linking_plus_to_minus_all_types = []
    for atom_type in sub_atoms_type_list:
        # A label belongs to ``atom_type`` when it *starts* with "<atom_type>_" (e.g. "np_3").
        # Anchoring at the start prevents a type from matching inside a longer name
        # ("p_" occurs in "np_3"), and the type is compared literally, not as a regex.
        prefix = atom_type + "_"
        sentence_links = []
        for s_idx, atoms in enumerate(atoms_batch):
            plus_atoms = []
            # label -> index of its first occurrence among this type's negative atoms
            # (same result as list.index, without the quadratic rescan)
            minus_position = {}
            n_minus = 0
            for i, atom in enumerate(atoms):
                is_positive = bool(atoms_polarity[s_idx, i])
                if not atom.startswith(prefix):
                    continue
                if is_positive:
                    plus_atoms.append(atom)
                else:
                    minus_position.setdefault(atom, n_minus)
                    n_minus += 1
            sentence_links.append(
                torch.as_tensor([minus_position.get(atom, -1) for atom in plus_atoms], dtype=torch.long))
        linking_plus_to_minus_all_types.append(
            pad_sequence(sentence_links, max_len=max_atoms_in_one_type // 2, padding_value=-1))
    return torch.stack(linking_plus_to_minus_all_types)
def category_to_atoms_axiom_links(category, categories_to_atoms):
    r"""
    Args:
        category : str of kind AtomCat | CategoryCat(dr or dl); may be prefixed by "GOAL:" or be "let"
        categories_to_atoms : recursive list; for a compound category it is extended in place
            with the atoms found and then returned
    Returns :
        List of atoms inside the category in prefix order
    """
    if category.startswith("GOAL:"):
        # ValueError on more than one ':' is kept on purpose (same as before)
        _, cat = category.split(':')
        return category_to_atoms_axiom_links(cat, categories_to_atoms)
    elif category == "let":
        return []
    # Indexed atom: an atom_map key immediately followed by "_<digits>" at the start of the string.
    # Raw string avoids the invalid "\d" escape; re.escape keeps keys from being read as regex syntax.
    elif any(re.match(re.escape(atom_type) + r"_\d+", category) for atom_type in atom_map.keys()):
        return [category]
    else:
        category_cut = regex.match(regex_categories_axiom_links, category).groups()
        for cat in category_cut:
            if cat is not None:
                categories_to_atoms += category_to_atoms_axiom_links(cat, [])
        return categories_to_atoms
def get_atoms_links_batch(category_batch):
    r"""
    Flatten each sentence of ``category_batch`` into its indexed atoms.

    Args:
        category_batch : batch of sentences, each an iterable of category strings
    Returns :
        list with one list per sentence: the outputs of ``category_to_atoms_axiom_links``
        for each of its categories, concatenated in order (prefix order)
    """
    return [
        [atom for category in sentence for atom in category_to_atoms_axiom_links(category, [])]
        for sentence in category_batch
    ]
#########################################################################################
################################ List of atoms ##########################################
#########################################################################################
def category_to_atoms(category, categories_to_atoms):
    r"""
    Args:
        category : str of kind AtomCat | CategoryCat(dr or dl); may be prefixed by "GOAL:" or be "let"
        categories_to_atoms : recursive list; for a compound category it is extended in place
            with the atoms found and then returned
    Returns:
        List of atoms inside the category in prefix order; an atom is a category equal to
        one of the keys of ``atom_map``
    """
    if category.startswith("GOAL:"):
        _, goal_category = category.split(':')
        return category_to_atoms(goal_category, categories_to_atoms)
    if category == "let":
        return []
    if category in atom_map.keys():
        return [category]
    # Compound category: recurse into every captured sub-category, in order.
    for operand in regex.match(regex_categories, category).groups():
        if operand is not None:
            categories_to_atoms += category_to_atoms(operand, [])
    return categories_to_atoms
def get_atoms_batch(category_batch):
    r"""
    Flatten each sentence of ``category_batch`` into its atoms.

    Args:
        category_batch : batch of sentences, each an iterable of category strings
    Returns:
        list with one list per sentence: the outputs of ``category_to_atoms`` for each
        of its categories, concatenated in order (prefix order)
    """
    return [
        [atom for category in sentence for atom in category_to_atoms(category, [])]
        for sentence in category_batch
    ]
#########################################################################################
################################ Polarity ###############################################
#########################################################################################
def category_to_atoms_polarity(category, polarity):
    r"""
    Args:
        category : str of kind AtomCat | CategoryCat(dr or dl); may be prefixed by "GOAL:" or be "let"
        polarity : polarity according to recursivity
    Returns:
        list of bool, one entry per atom of ``category`` in prefix order (True for positive,
        False for negative). It is consumed position-by-position alongside the output of
        ``category_to_atoms`` (see ``get_GOAL``), so both must agree on what an atom is.
        Categories matching none of the handled forms contribute no entries.
    """
    def operands():
        # Sub-categories captured by the recursive category regex, in order.
        return [cat for cat in regex.match(regex_categories, category).groups() if cat is not None]

    category_to_polarity = []
    # final word: its goal category is positive
    if category.startswith("GOAL:"):
        _, cat = category.split(':')
        # Atomicity is tested by exact key equality, as in the atomic branch below and in
        # category_to_atoms. A prefix match would treat a compound whose name starts with an
        # atom type as a single atom and desynchronise polarities from atoms.
        if cat in atom_map.keys():
            category_to_polarity.append(True)
        else:
            category_to_polarity += category_to_atoms_polarity(cat, True)
    elif category == "let":
        pass
    # the word has an atomic category
    elif category in atom_map.keys():
        category_to_polarity.append(not polarity)
    # otherwise it is a compound formula
    # dr = / : the right side gets the flipped polarity
    elif category.startswith("dr"):
        category_cut = operands()
        left_side, right_side = category_cut[0], category_cut[1]
        category_to_polarity += category_to_atoms_polarity(left_side, polarity)
        category_to_polarity += category_to_atoms_polarity(right_side, not polarity)
    # dl = \ and p : the left side gets the flipped polarity
    elif category.startswith(("dl", "p")):
        category_cut = operands()
        left_side, right_side = category_cut[0], category_cut[1]
        category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
        category_to_polarity += category_to_atoms_polarity(right_side, polarity)
    # box / dia : single operand, polarity passed through unchanged
    elif category.startswith(("box", "dia")):
        category_to_polarity += category_to_atoms_polarity(operands()[0], polarity)
    return category_to_polarity
def find_pos_neg_idexes(atoms_batch):
    r"""
    Compute the polarity of every atom of every sentence.

    Args:
        atoms_batch : batch of sentences, each an iterable of category strings (despite the name,
            each element is passed to ``category_to_atoms_polarity`` as a category)
    Returns:
        list with one list per sentence: the outputs of ``category_to_atoms_polarity(category, True)``
        for each of its categories, concatenated in order (prefix order)
    """
    return [
        [is_positive for category in sentence for is_positive in category_to_atoms_polarity(category, True)]
        for sentence in atoms_batch
    ]
#########################################################################################
################################ GOAL ###############################################
#########################################################################################
def get_GOAL(max_atoms_in_sentence, categories_batch):
    r"""
    Collect atoms and polarities per sentence, then balance positives and negatives per atom type.

    For every key of ``atom_map``, extra atoms of that type are appended at the end of the
    sentence until it holds as many positive as negative occurrences: positive ones when
    negatives outnumber positives, negative ones otherwise.

    Args:
        max_atoms_in_sentence : ``max_len`` given to ``pad_sequence`` for the polarity tensor
        categories_batch : batch of sentences, each a sequence of category strings
    Returns:
        atoms_batch : list (one per sentence) of atom strings, including the appended atoms
        polarities : ``pad_sequence`` of one bool tensor per sentence, padded with 0 (False)
    """
    polarities = find_pos_neg_idexes(categories_batch)
    atoms_batch = get_atoms_batch(categories_batch)
    for s_idx in range(len(atoms_batch)):
        sentence_atoms = atoms_batch[s_idx]
        sentence_polarities = polarities[s_idx]
        for atom_type in list(atom_map.keys()):
            n_plus = sum(1 for i, atom in enumerate(sentence_atoms)
                         if sentence_polarities[i] and atom == atom_type)
            n_minus = sum(1 for i, atom in enumerate(sentence_atoms)
                          if not sentence_polarities[i] and atom == atom_type)
            # Each appended atom closes the gap by exactly one, so the deficit computed once
            # gives the full number of atoms to add (no need to recount after every append).
            deficit = n_minus - n_plus
            n_missing = abs(deficit)
            sentence_atoms.extend([atom_type] * n_missing)
            sentence_polarities.extend([deficit > 0] * n_missing)
    return atoms_batch, pad_sequence([torch.as_tensor(sentence_polarities, dtype=torch.bool)
                                      for sentence_polarities in polarities],
                                     max_len=max_atoms_in_sentence, padding_value=0)
#########################################################################################
################################ Prepare encoding ###############################################
#########################################################################################
def _is_atom_of_type(atom_name, atom_type):
    """True when ``atom_name`` is ``atom_type`` itself or ``atom_type`` followed by a "_..." suffix."""
    # Anchored on both the start and the separator, so e.g. type "n" does not claim "np".
    return atom_name == atom_type or atom_name.startswith(atom_type + "_")


def _atom_indices_by_type(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type, positive):
    r"""
    Shared implementation of ``get_pos_idx`` / ``get_neg_idx``.

    For each key of ``atom_map_redux`` and each sentence, collect the positions ``i`` whose token
    decodes (through the inverse of ``atom_map``) to an atom of that type and whose polarity,
    as a bool, equals ``positive``. Each sentence's positions are padded with -1 to
    ``max_atoms_in_one_type // 2`` by ``pad_sequence``; the per-type results are stacked and
    permuted so that the batch dimension comes first.
    """
    inverse_atom_map = {v: k for k, v in atom_map.items()}
    indices_per_type = []
    for atom_type in atom_map_redux.keys():
        sentence_indices = []
        for s_idx, sentence in enumerate(atoms_batch_tokenized):
            positions = [i for i, token in enumerate(sentence)
                         if _is_atom_of_type(inverse_atom_map[int(token)], atom_type)
                         and bool(atoms_polarity_batch[s_idx][i]) == positive]
            # Explicit dtype: torch.as_tensor([]) would otherwise produce a float tensor.
            sentence_indices.append(torch.as_tensor(positions, dtype=torch.long))
        indices_per_type.append(
            pad_sequence(sentence_indices, max_len=max_atoms_in_one_type // 2, padding_value=-1))
    return torch.stack(indices_per_type).permute(1, 0, 2)


def get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
    r"""
    Positions of the positive atoms of each reduced atom type.

    Args:
        atoms_batch_tokenized : batch of sentences of atom token ids (decoded with the inverse of ``atom_map``)
        atoms_polarity_batch : per-atom polarities, indexed as ``[s_idx][i]`` (truthy = positive)
        max_atoms_in_one_type : configuration; positions are padded with -1 to ``max_atoms_in_one_type // 2``
    Returns:
        stacked long tensor with the batch dimension first, presumably
        (batch_size, len(atom_map_redux), max_atoms_in_one_type // 2) — NOTE(review): depends on
        ``utils.pad_sequence`` returning batch-first tensors.
    """
    return _atom_indices_by_type(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type,
                                 positive=True)


def get_neg_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
    r"""
    Positions of the negative atoms of each reduced atom type.

    Same arguments and return layout as ``get_pos_idx``, selecting atoms whose polarity is falsy.
    """
    return _atom_indices_by_type(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type,
                                 positive=False)