# utils_linker.py
import re

import pandas as pd
import regex
import torch
from torch.nn import Sequential, Linear, Dropout, GELU
from torch.nn import Module

from Linker.atom_map import atom_map, atom_map_redux
from utils import pad_sequence


class FFN(Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear.

    Both linear layers are created without bias. The output width is
    ``d_out`` when given, otherwise it equals the input width ``d_model``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1, d_out=None):
        super().__init__()
        # Fall back to the input width when no explicit output width is given.
        output_dim = d_model if d_out is None else d_out
        layers = [
            Linear(d_model, d_ff, bias=False),
            GELU(),
            Dropout(dropout),
            Linear(d_ff, output_dim, bias=False),
        ]
        self.ffn = Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack to ``x`` and return the result."""
        projected = self.ffn(x)
        return projected


################################ Regex ########################################
# Both patterns describe one compound category written as
# ``name(<digits>,<arg>[,<arg>])``. Each argument is either a nested compound
# category, matched recursively through ``(?R)`` (capture groups 1 and 3), or a
# plain ``\w+`` word (capture groups 2 and 4). ``(?R)`` is a feature of the
# third-party ``regex`` module, so these patterns cannot be used with ``re``.
# The two patterns are currently identical: the first is used on categories
# carrying a link index, the second on plain categories.
regex_categories_axiom_links = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'


# region get true axiom links
def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
    r"""
    Build, for every atom type, the gold links from positive atoms to negative atoms.

    Args:
        max_atoms_in_one_type : configuration, every sentence is padded to max_atoms_in_one_type // 2 links per type
        atoms_polarity : (batch_size, max_atoms_in_sentence) indexed as atoms_polarity[s_idx, i], truthy for positive atoms
        batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
    Returns:
        batch_true_links : the padded per-type results stacked on a new leading dimension, one slice per key of
        atom_map_redux. For one sentence and one type, entry k is the position (among the negative atoms of that
        type) of the atom whose string equals the k-th positive atom, or -1 when no such atom exists.
        NOTE(review): unlike get_pos_idx / get_neg_idx no permute is applied here, so the atom type is the
        leading dimension - confirm this is what callers expect.
    """
    atoms_batch = get_atoms_links_batch(batch_axiom_links)
    linking_plus_to_minus_all_types = []
    for atom_type in atom_map_redux:
        # Atom name, optional "_<subtype>" suffix, mandatory "_<link index>"; compiled once per type.
        atom_pattern = re.compile(atom_type + r"(_{1}\w+)?_\d+\Z")

        links_per_sentence = []
        for s_idx, atoms in enumerate(atoms_batch):
            # Keep only the atoms of this type, split by polarity.
            positive_atoms, negative_atoms = [], []
            for i, atom in enumerate(atoms):
                is_positive = bool(atoms_polarity[s_idx, i])
                if atom_pattern.match(atom):
                    (positive_atoms if is_positive else negative_atoms).append(atom)

            # First position of every negative atom (same result as list.index, without the quadratic scan).
            first_negative_position = {}
            for position, atom in enumerate(negative_atoms):
                first_negative_position.setdefault(atom, position)

            links_per_sentence.append(torch.as_tensor(
                [first_negative_position.get(atom, -1) for atom in positive_atoms], dtype=torch.long))

        linking_plus_to_minus_all_types.append(
            pad_sequence(links_per_sentence, max_len=max_atoms_in_one_type // 2, padding_value=-1))

    return torch.stack(linking_plus_to_minus_all_types)


def category_to_atoms_axiom_links(category, categories_to_atoms):
    r"""
    Flatten one indexed category into the atoms it contains.

    Args:
        category : str, an indexed atom (np_1), a compound category (dr(0,np_1,n_2)) or a goal (GOAL:np_7)
        categories_to_atoms : list that is extended in place with the atoms found in a compound category
    Returns :
        List of atoms inside the category in prefix order. For an atom (also inside a GOAL) a fresh
        one-element list is returned and categories_to_atoms is left untouched.
    Raises :
        AttributeError if the category is neither an indexed atom nor matched by regex_categories_axiom_links
    """
    if category.startswith("GOAL:"):
        _, cat = category.split(':')
        return category_to_atoms_axiom_links(cat, categories_to_atoms)

    # An indexed atom starts with a known atom name immediately followed by "_<digits>".
    if any(re.match(atom_type + r"_\d+", category) for atom_type in atom_map):
        return [category]

    # Compound category: recurse into every captured argument, left to right.
    category_cut = regex.match(regex_categories_axiom_links, category).groups()
    for cat in category_cut:
        if cat is not None:
            categories_to_atoms += category_to_atoms_axiom_links(cat, [])
    return categories_to_atoms


def get_atoms_links_batch(category_batch):
    r"""
    Flatten every sentence of indexed categories into one list of atoms.

    Args:
        category_batch : (batch_size, len_sentence) categories carrying the _i link index
    Returns :
        One list per sentence: the atoms of each word in prefix order, each word followed by "[SEP]".
        "let" entries are skipped; the atoms of a "GOAL:" entry are put in front of everything
        collected so far, without a "[SEP]".
    """
    flattened_batch = []
    for sentence in category_batch:
        sentence_atoms = []
        for category in sentence:
            if category.startswith("GOAL:"):
                # The goal goes first in the sentence and gets no separator.
                sentence_atoms = category_to_atoms_axiom_links(category, []) + sentence_atoms
                continue
            if category == "let":
                continue
            sentence_atoms.extend(category_to_atoms_axiom_links(category, []))
            sentence_atoms.append("[SEP]")
        flattened_batch.append(sentence_atoms)
    return flattened_batch


# Import-time smoke test: prints the links computed by get_axiom_links for one
# hand-written sentence (30 polarity flags, 9 indexed categories plus a GOAL).
# NOTE(review): this statement runs, and prints, every time the module is imported.
print("test to create links ",
      get_axiom_links(20, torch.stack([torch.as_tensor(
          [True, False, True, False, False, False, True, False, True, False,
           False, True, False, False, False, True, False, False, True, False,
           True, False, False, True, False, False, False, False, False, False])]),
                      [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)',
                        'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]))


# endregion

# region get atoms in sentence

def category_to_atoms(category, categories_to_atoms):
    r"""
    Flatten one category (without link indexes) into the atoms it contains.

    Args:
        category : str of kind AtomCat | CategoryCat(dr or dl), optionally prefixed with "GOAL:"
        categories_to_atoms : list that is extended in place with the atoms found in a compound category
    Returns:
        List of atoms inside the category in prefix order. For an atom (also inside a GOAL) a fresh
        one-element list is returned and categories_to_atoms is left untouched.
    Raises:
        AttributeError if the category is neither a key of atom_map nor matched by regex_categories
    """
    if category.startswith("GOAL:"):
        _, cat = category.split(':')
        return category_to_atoms(cat, categories_to_atoms)

    # Atoms are exactly the keys of atom_map: a membership test replaces the scan over all keys.
    if category in atom_map:
        return [category]

    # Compound category: recurse into every captured argument, left to right.
    category_cut = regex.match(regex_categories, category).groups()
    for cat in category_cut:
        if cat is not None:
            categories_to_atoms += category_to_atoms(cat, [])
    return categories_to_atoms


def get_atoms_batch(category_batch):
    r"""
    Flatten every sentence of categories into one list of atoms.

    Args:
        category_batch : (batch_size, len_sentence) categories, one per word
    Returns:
        One list per sentence: the atoms of each word in prefix order, each word followed by "[SEP]".
        "let" entries contribute nothing.
    """
    flattened_batch = []
    for sentence in category_batch:
        sentence_atoms = []
        for category in (word for word in sentence if word != "let"):
            sentence_atoms.extend(category_to_atoms(category, []))
            sentence_atoms.append("[SEP]")
        flattened_batch.append(sentence_atoms)
    return flattened_batch


# Import-time smoke test: prints the flattened atoms of one hand-written sentence.
# NOTE(review): this statement runs, and prints, every time the module is imported.
print(" test for get atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']",
      get_atoms_batch([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']]))


# endregion

# region calculate num atoms per category

def category_to_num_atoms(category, categories_to_atoms):
    r"""
    Count the atoms contained in one category.

    Args:
        category : str of kind AtomCat | CategoryCat(dr or dl), "let", or any of these prefixed with "GOAL:"
        categories_to_atoms : int, running count the atoms of a compound category are added to
    Returns:
        int : 0 for "let", 1 for an atom, and for a compound category categories_to_atoms plus the number
        of atoms found in its arguments. For a "GOAL:" category the running count is ignored and the
        count of the wrapped category is returned.
    Raises:
        AttributeError if the category is neither "let", a key of atom_map, nor matched by regex_categories
    """
    if category.startswith("GOAL:"):
        _, cat = category.split(':')
        return category_to_num_atoms(cat, 0)
    if category == "let":
        return 0
    # Atoms are exactly the keys of atom_map.
    if category in atom_map:
        return 1

    # Compound category: add up the atoms of every captured argument.
    category_cut = regex.match(regex_categories, category).groups()
    return categories_to_atoms + sum(
        category_to_num_atoms(cat, 0) for cat in category_cut if cat is not None)


def get_num_atoms_batch(category_batch, max_len_sentence):
    r"""
    Count, word by word, how many atom slots every sentence needs.

    Args:
        category_batch : (batch_size, len_sentence) categories, one per word
        max_len_sentence : max_len_sentence parameter, forwarded to pad_sequence as max_len
    Returns:
        The result of pad_sequence (padding value 0) over one integer tensor per sentence. Each tensor
        starts with a 0 entry, followed by one count per word: the number of atoms of the word plus 1
        for its "[SEP]", or the bare count for "let".
    """
    counts_per_sentence = []
    for sentence in category_batch:
        # Leading slot starts at 0 (get_GOAL later adds the goal's atoms to index 0).
        counts = [0]
        for category in sentence:
            n_atoms = category_to_num_atoms(category, 0)
            # Every real word is followed by a SEP token, "let" is not.
            counts.append(n_atoms if category == "let" else n_atoms + 1)
        counts_per_sentence.append(torch.as_tensor(counts))
    return pad_sequence(counts_per_sentence, max_len=max_len_sentence, padding_value=0)


# Import-time smoke test: prints the per-word atom counts of one hand-written sentence.
# NOTE(review): this statement runs, and prints, every time the module is imported.
print(" test for get number of atoms in categories on ['dr(0,s,np)', 'let']",
      get_num_atoms_batch([["dr(0,s,np)", "let"]], 10))


# endregion

# region get polarity

def category_to_atoms_polarity(category, polarity):
    r"""
    Compute the polarity of every atom of one category, in prefix order.

    Args:
        category : str of kind AtomCat | CategoryCat(dr, dl, p, box or dia), optionally prefixed with "GOAL:"
        polarity : polarity according to recursivity
    Returns:
        List of booleans, one per atom: an atom contributes ``not polarity``. "dr" keeps the polarity for
        its first argument and flips it for the second, "dl" and "p" do the opposite, "box" and "dia"
        pass it unchanged to their single argument. A goal whose content starts with an atom name yields
        [True]; any other goal is processed with polarity True. A category that is neither an atom nor
        one of the constructors above yields an empty list.
    """
    # Final word (goal)
    if category.startswith("GOAL:"):
        _, cat = category.split(':')
        starts_with_atom = [re.match(r'' + atom_type, cat) is not None for atom_type in atom_map.keys()]
        if any(starts_with_atom):
            return [True]
        return category_to_atoms_polarity(cat, True)

    # The word has an atomic category
    if any(category == atom_type for atom_type in atom_map.keys()):
        return [not polarity]

    # Otherwise it is a compound formula: pick the polarities handed to its arguments.
    if category.startswith("dr"):
        # dr = /
        unary = False
        side_polarities = (polarity, not polarity)
    elif category.startswith(("dl", "p")):
        # dl = \ ; the product p is treated the same way
        unary = False
        side_polarities = (not polarity, polarity)
    elif category.startswith(("box", "dia")):
        unary = True
        side_polarities = (polarity,)
    else:
        return []

    arguments = [cat for cat in regex.match(regex_categories, category).groups() if cat is not None]
    if unary:
        return category_to_atoms_polarity(arguments[0], side_polarities[0])

    left_side, right_side = arguments[0], arguments[1]
    left_polarities = category_to_atoms_polarity(left_side, side_polarities[0])
    right_polarities = category_to_atoms_polarity(right_side, side_polarities[1])
    return left_polarities + right_polarities


def find_pos_neg_idexes(atoms_batch):
    r"""
    Compute the polarities of all atoms of every sentence.

    Args:
        atoms_batch : (batch_size, len_sentence) categories, one per word
    Returns:
        One list of booleans per sentence: the polarities of each word's atoms in prefix order
        (computed with starting polarity True), each word followed by one extra False.
        "let" entries contribute nothing.
    """
    polarities_batch = []
    for sentence in atoms_batch:
        sentence_polarities = []
        for category in sentence:
            if category != "let":
                sentence_polarities.extend(category_to_atoms_polarity(category, True))
                # One extra slot per word, lining up with the "[SEP]" get_atoms_batch appends.
                sentence_polarities.append(False)
        polarities_batch.append(sentence_polarities)
    return polarities_batch


# Import-time smoke test: prints the atom polarities of one hand-written sentence.
# NOTE(review): this statement runs, and prints, every time the module is imported.
print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np'] \n",
    find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
                          'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']]))


# endregion

# region get atoms and polarities with GOAL

def get_GOAL(max_len_sentence, df_axiom_links):
    r"""
    Compute atoms, polarities and atom counts of every sentence, with the goal put in front.

    Args:
        max_len_sentence : max_len_sentence parameter, forwarded to get_num_atoms_batch
        df_axiom_links : table with a column "Z" (categories of each sentence) and a column "Y"
            (the same categories with link indexes, whose last entry is the goal, e.g. GOAL:np_7)
    Returns:
        atoms_batch : per sentence, the goal atoms followed by the output of get_atoms_batch
        polarities : per sentence, the goal polarities followed by the output of find_pos_neg_idexes
        num_atoms_batch : output of get_num_atoms_batch with the number of goal atoms added to index 0
    Raises:
        AttributeError if a goal contains no "<name>_<digits>" part
    """
    categories_batch = df_axiom_links["Z"]
    categories_with_goal = df_axiom_links["Y"]
    polarities = find_pos_neg_idexes(categories_batch)
    atoms_batch = get_atoms_batch(categories_batch)
    num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence)

    # The three results above are built by iterating column "Z" in order, so column "Y" is walked
    # in order as well. Looking rows up by label (categories_with_goal[s_idx]) would pick the wrong
    # row, or raise, whenever the table does not carry the default 0..n-1 index.
    for s_idx, sentence_with_goal in enumerate(categories_with_goal):
        goal = sentence_with_goal[-1]
        polarities_goal = category_to_atoms_polarity(goal, True)
        # Strip the link index: "GOAL:np_7" -> "np".
        # NOTE(review): only the first "<name>_<digits>" occurrence is kept - confirm goals are always atomic.
        goal = re.search(r"(\w+)_\d+", goal).groups()[0]
        atoms = category_to_atoms(goal, [])

        # The goal is prepended as is: no "[SEP]" / False / extra count is added for it.
        atoms_batch[s_idx] = atoms + atoms_batch[s_idx]
        polarities[s_idx] = polarities_goal + polarities[s_idx]
        num_atoms_batch[s_idx][0] += len(atoms)

    return atoms_batch, polarities, num_atoms_batch


# Import-time smoke test for get_GOAL: a one-sentence table whose "Z" column holds the plain
# categories and whose "Y" column holds the same categories with link indexes plus the goal.
# NOTE(review): this fixture is built, and the result printed, every time the module is imported.
df_axiom_links = pd.DataFrame({"Z": [['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
                                      'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']],
                               "Y": [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6',
                                      'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9',
                                      'GOAL:np_7']]})
print(" test for get GOAL ", get_GOAL(10, df_axiom_links))


# endregion

# region get idx for pos and neg

def _get_polarity_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, positive):
    """Collect, per atom type and sentence, the positions of the atoms whose polarity equals ``positive``."""
    idx_all_types = []
    for atom_type in atom_map_redux:
        # Atom name with an optional "_<subtype>" suffix; compiled once per type.
        atom_pattern = re.compile(atom_type + r"(_{1}\w+)?\Z")
        idx_per_sentence = [
            # dtype is forced to long: an empty index list would otherwise become a float tensor.
            torch.as_tensor(
                [i for i, atom in enumerate(sentence)
                 if atom_pattern.match(atom) and bool(atoms_polarity_batch[s_idx][i]) == positive],
                dtype=torch.long)
            for s_idx, sentence in enumerate(atoms_batch)]
        idx_all_types.append(
            pad_sequence(idx_per_sentence, max_len=max_atoms_in_one_type // 2, padding_value=-1))

    # Stacking puts the atom type first; swap the first two dimensions.
    return torch.stack(idx_all_types).permute(1, 0, 2)


def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
    r"""
    Find where the positive atoms of every type are located in each sentence.

    Args:
        atoms_batch : (batch_size, max_atoms_in_sentence) atoms of each sentence
        atoms_polarity_batch : (batch_size, max_atoms_in_sentence) indexed as [s_idx][i], truthy for positive atoms
        max_atoms_in_one_type : configuration, index lists are padded with -1 to max_atoms_in_one_type // 2
    Returns:
        Long tensor: the padded per-type index tensors (one per key of atom_map_redux) stacked and
        permuted with (1, 0, 2), holding the positions of the positive atoms of each type
    """
    return _get_polarity_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, True)


def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
    r"""
    Find where the negative atoms of every type are located in each sentence.

    Args:
        atoms_batch : (batch_size, max_atoms_in_sentence) atoms of each sentence
        atoms_polarity_batch : (batch_size, max_atoms_in_sentence) indexed as [s_idx][i], truthy for positive atoms
        max_atoms_in_one_type : configuration, index lists are padded with -1 to max_atoms_in_one_type // 2
    Returns:
        Long tensor: the padded per-type index tensors (one per key of atom_map_redux) stacked and
        permuted with (1, 0, 2), holding the positions of the negative atoms of each type
    """
    return _get_polarity_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, False)


# Import-time smoke test: prints the negative-atom positions returned by get_neg_idx for one
# hand-written list of atoms and a 12-entry polarity row.
# NOTE(review): this statement runs, and prints, every time the module is imported.
print(" test for cut into pos neg on ['dr(0,s,np)', 's']",
      get_neg_idx([['s', 's', 'np', 's', 'np', '[SEP]', 's', '[SEP]']],
                  torch.as_tensor(
                      [[True, True, False, False,
                        True, False, False, False,
                        False, False,
                        False, False]]), 10))

# endregion