Skip to content
Snippets Groups Projects
utils_linker.py 11.11 KiB
import re
import regex
import torch
from torch.nn import Sequential, Linear, Dropout, GELU
from torch.nn import Module
from Linker.atom_map import atom_map
from utils import pad_sequence


class FFN(Module):
    """Position-wise feed-forward block: Linear -> GELU -> Dropout -> Linear.

    Both projections are bias-free. The hidden width is ``d_ff``, and the output
    width equals the input width ``d_model``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Submodules are created in the same order as before. This keeps the
        # parameter initialisation order and the state_dict keys
        # (ffn.0.weight, ffn.3.weight) unchanged.
        expand = Linear(d_model, d_ff, bias=False)
        activation = GELU()
        regularise = Dropout(dropout)
        project_back = Linear(d_ff, d_model, bias=False)
        self.ffn = Sequential(expand, activation, regularise, project_back)

    def forward(self, x):
        """Apply the feed-forward stack to ``x``, whose last dimension is ``d_model``."""
        output = self.ffn(x)
        return output


# Pattern for a compound category of the form ``name(<digits>,<arg1>,<arg2>)``.
# Each argument slot is either a nested compound category, matched recursively
# through ``(?R)`` (supported by the third-party ``regex`` module, not ``re``),
# or a plain ``\w+`` word such as an indexed atom.
# Capture groups:
#   1 = nested first argument,  2 = word first argument,
#   3 = nested second argument, 4 = word second argument.
# The parsers below call ``.groups()`` and drop the ``None`` entries, which
# leaves the argument list.
regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'


#########################################################################################
################################ List of atoms with _i ####################################
#########################################################################################


def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
    r"""
    Build the gold positive -> negative axiom links, one tensor per atom type.

    Args:
        max_atoms_in_one_type : configuration. Each per-sentence link list is
            padded to ``max_atoms_in_one_type // 2`` entries.
        atoms_polarity : (batch_size, max_atoms_in_sentence) polarity of each atom,
            read as ``atoms_polarity[s_idx, i]`` (truthy = positive)
        batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
    Returns:
        LongTensor stacked over the atom types ``list(atom_map)[:-1]`` (dimension 0).
        The expected shape is (len(atom_map) - 1, batch_size,
        max_atoms_in_one_type // 2), assuming ``utils.pad_sequence`` returns
        (batch_size, max_len); TODO confirm.
        Entry [t, b, k] is the position, within the negative atoms of type t in
        sentence b, of the negative atom equal to the k-th positive atom of
        type t (same ``<type>_<i>`` string). It is -1 when no such negative atom
        exists, and -1 is also the padding value.
    """
    atoms_batch = get_atoms_links_batch(batch_axiom_links)
    linking_plus_to_minus_all_types = []
    # The last atom_map key is skipped, as before (presumably a padding token; TODO confirm).
    for atom_type in list(atom_map.keys())[:-1]:
        # Exact type match on "<atom_type>_<digits>". A substring test on
        # atom_type + "_" would also collect atoms of longer types that share
        # the prefix (e.g. type "s" would also collect "s_whq_3").
        type_pattern = re.compile(re.escape(atom_type) + r"_\d+")

        sentence_links = []
        for s_idx, sentence_atoms in enumerate(atoms_batch):
            # Keep only this type's atoms, split by polarity, preserving order.
            plus_atoms = []
            minus_atoms = []
            for i, atom in enumerate(sentence_atoms):
                if not type_pattern.fullmatch(atom):
                    continue
                if atoms_polarity[s_idx, i]:
                    plus_atoms.append(atom)
                else:
                    minus_atoms.append(atom)

            # First position of each negative atom (same result as list.index),
            # built once so each lookup is O(1) instead of a linear scan.
            minus_position = {}
            for j, atom in enumerate(minus_atoms):
                minus_position.setdefault(atom, j)

            sentence_links.append(torch.as_tensor(
                [minus_position.get(atom, -1) for atom in plus_atoms], dtype=torch.long))

        linking_plus_to_minus_all_types.append(
            pad_sequence(sentence_links, max_len=max_atoms_in_one_type // 2, padding_value=-1))

    return torch.stack(linking_plus_to_minus_all_types)


def category_to_atoms_axiom_links(category, categories_to_atoms):
    r"""
    Flatten a category into its atoms, keeping each atom's ``_i`` index suffix.

    Args:
        category : str of kind AtomCat | CategoryCat(dr or dl), or ``GOAL:<atom>``
        categories_to_atoms : accumulator list. Atoms found in sub-categories are
            appended to it in place (``+=``) and it is returned; callers in this
            file pass a fresh ``[]``.
    Returns :
        List of atoms inside the category in prefix order (e.g. ``"np_3"``)
    """
    if category.startswith("GOAL:"):
        # "GOAL:<atom>" -> everything after the first colon.
        return [category.split(':', 1)[1]]
    # An indexed atom: "<atom_type>_<digits>" at the start, for some atom type.
    # Atom types are escaped so names such as "[SEP]" are matched literally.
    if any(re.match(re.escape(atom_type) + r"_\d+", category) for atom_type in atom_map):
        return [category]
    # Compound category: recurse into each non-empty argument in order.
    category_cut = regex.match(regex_categories, category).groups()
    for cat in category_cut:
        if cat is not None:
            categories_to_atoms += category_to_atoms_axiom_links(cat, [])
    return categories_to_atoms


def get_atoms_links_batch(category_batch):
    r"""
    Flatten every sentence of a batch into its indexed atoms.

    Args:
        category_batch : batch of sentences, each a sequence of category strings
    Returns:
        list with one list per sentence. Each holds the atoms (``_i`` suffix kept)
        of all the sentence's categories, concatenated in order.
    """
    return [
        [atom
         for category in sentence
         for atom in category_to_atoms_axiom_links(category, [])]
        for sentence in category_batch
    ]


#########################################################################################
################################ List of atoms ##########################################
#########################################################################################


def category_to_atoms(category, categories_to_atoms):
    r"""
    Flatten a category into its atom types (index suffix removed), in prefix order.

    Args:
        category : str of kind AtomCat | CategoryCat(dr or dl), or ``GOAL:<atom>``
        categories_to_atoms : accumulator list. Atoms found in sub-categories are
            appended to it in place (``+=``) and it is returned; callers in this
            file pass a fresh ``[]``.
    Returns:
        List of atom types inside the category in prefix order
        (e.g. ``"np_3"`` -> ``"np"``)
    """
    # Leading run of letters/underscores before the trailing "_<digits>".
    # NOTE(review): "|" inside [...] is a literal pipe, not alternation; it is
    # kept for compatibility. Confirm whether atom names can contain it.
    type_pattern = r'([a-zA-Z|_]+)_\d+'
    if category.startswith("GOAL:"):
        goal_atom = category.split(':', 1)[1]
        return [re.match(type_pattern, goal_atom).group(1)]
    # Atom types are escaped so names such as "[SEP]" are matched literally.
    if any(re.match(re.escape(atom_type) + r"_\d+", category) for atom_type in atom_map):
        return [re.match(type_pattern, category).group(1)]
    # Compound category: recurse into each non-empty argument in order.
    category_cut = regex.match(regex_categories, category).groups()
    for cat in category_cut:
        if cat is not None:
            categories_to_atoms += category_to_atoms(cat, [])
    return categories_to_atoms


def get_atoms_batch(category_batch):
    r"""
    Flatten every sentence of a batch into its atom types.

    Args:
        category_batch : batch of sentences, each a sequence of category strings
    Returns:
        list with one list per sentence. Each holds the atom types (index suffix
        removed) of all the sentence's categories, concatenated in order.
    """
    return [
        [atom_type
         for category in sentence
         for atom_type in category_to_atoms(category, [])]
        for sentence in category_batch
    ]


#########################################################################################
################################ Polarity ###############################################
#########################################################################################

def _is_polarity_leaf(category):
    """Return True if ``category`` gets exactly one polarity entry.

    This holds for an indexed atom (``<atom_type>_<digits>`` prefix, for some
    atom_map key) and for a category starting with ``dia`` or ``box``.
    """
    if category.startswith(("dia", "box")):
        return True
    # Atom types are escaped so names such as "[SEP]" are matched literally.
    return any(re.match(re.escape(atom_type) + r"_\d+", category) for atom_type in atom_map)


def category_to_atoms_polarity(category, polarity):
    r"""
    Compute the polarity of every atom of ``category``, in prefix order.

    Args:
        category : str of kind AtomCat | CategoryCat(dr or dl), or ``GOAL:<atom>``
        polarity : polarity of ``category`` within its enclosing formula
            (truthy = positive). It only matters for dr/dl categories.
    Returns:
        list of bool, one entry per atom-like leaf (True = positive, False = negative):
        - ``GOAL:`` categories give ``[True]``;
        - a bare atom or ``dia``/``box`` category passed directly gives ``[False]``,
          whatever ``polarity`` is;
        - any other category that is neither ``dr`` nor ``dl`` gives ``[]``.

    NOTE(review): a ``dia``/``box`` sub-category contributes exactly one entry
    here, while ``category_to_atoms`` does not special-case them. Confirm they
    only ever wrap a single atom, otherwise the two lists will misalign.
    """
    if category.startswith("GOAL:"):
        return [True]
    if _is_polarity_leaf(category):
        return [False]

    positive = bool(polarity)
    if category.startswith("dr"):
        # dr = / : the left side keeps the polarity, the right side flips it.
        left_polarity, right_polarity = positive, not positive
    elif category.startswith("dl"):
        # dl = \ : the left side flips the polarity, the right side keeps it.
        left_polarity, right_polarity = not positive, positive
    else:
        return []

    category_cut = [cat for cat in regex.match(regex_categories, category).groups() if cat is not None]
    left_side, right_side = category_cut[0], category_cut[1]

    category_to_polarity = []
    for side, side_polarity in ((left_side, left_polarity), (right_side, right_polarity)):
        if _is_polarity_leaf(side):
            # An atom reached with polarity p is recorded as the opposite value.
            category_to_polarity.append(not side_polarity)
        else:
            category_to_polarity += category_to_atoms_polarity(side, side_polarity)
    return category_to_polarity


def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
    r"""
    Compute the padded per-atom polarities of every sentence in a batch.

    Args:
        max_atoms_in_sentence : configuration; length each sentence is padded to
        atoms_batch : batch of sentences, each a sequence of category strings
    Returns:
        the result of ``utils.pad_sequence`` on one tensor per sentence. Each
        tensor holds the polarities (computed with top-level polarity True) of
        all atoms of the sentence's categories, concatenated in order and padded
        with 0 to ``max_atoms_in_sentence``. The shape is presumably
        (batch_size, max_atoms_in_sentence); TODO confirm against utils.pad_sequence.
    """
    per_sentence_polarities = [
        torch.as_tensor([
            is_positive
            for category in sentence
            for is_positive in category_to_atoms_polarity(category, True)
        ])
        for sentence in atoms_batch
    ]
    return pad_sequence(per_sentence_polarities,
                        max_len=max_atoms_in_sentence, padding_value=0)