Skip to content
Snippets Groups Projects
postprocessing.py 4.64 KiB
import re

import graphviz
import numpy as np
import regex
from Linker.atom_map import atom_map, atom_map_redux

# Pattern for a compound category written "op(<digits>,<left>,<right>)",
# e.g. "dr(0,s,n)" or "dr(0,dl(0,n,np),n)".
# Each operand is either a nested compound category, matched recursively with
# (?R) (recursion is supported by the third-party ``regex`` module, not ``re``),
# or a bare \w+ name.
# Capture groups: 1 = nested left operand, 2 = bare left operand,
#                 3 = nested right operand, 4 = bare right operand.
# Groups 3/4 may both be unset for a one-operand form such as "box(0,np)"
# (the "," is optional and the operand groups are repeated with "*").
regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'


def _split_category(category, arity):
    r"""
    Return the first ``arity`` direct sub-categories of a compound category.

    :param category: compound category string, e.g. ``"dr(0,s,n)"``
    :param arity: number of sub-categories the operator is expected to have
    :return: list of ``arity`` sub-category strings
    :raises ValueError: if ``category`` does not match ``regex_categories`` or
        yields fewer than ``arity`` sub-categories
    """
    match = regex.match(regex_categories, category)
    if match is None:
        raise ValueError(f"Cannot parse category {category!r}")
    # unset capture groups come back as None; keep only the operands that matched
    sub_categories = [cat for cat in match.groups() if cat is not None]
    if len(sub_categories) < arity:
        raise ValueError(f"Category {category!r} has fewer than {arity} sub-categories")
    return sub_categories[:arity]


def recursive_linking(links, dot, category, parent_id, word_idx, depth,
                      polarity, compt_plus, compt_neg, node_path=""):
    r"""
    Recursively add the nodes and tree edges of one category to ``dot``.

    Categories that are keys of ``atom_map`` become leaf nodes with id
    ``"<atom>_<polarity>_<index>"``. Those ids are what the caller uses to draw
    the axiom links. Compound categories (``dr``, ``dl``, ``p``, ``box``, ``dia``)
    become inner nodes whose sub-categories are visited recursively. Any other
    category becomes a childless inner node.

    :param links: link predictions. For a negative atom with occurrence index ``k``,
        the first position where ``links[atom_map_redux[atom]]`` equals ``k`` is
        used as that atom's index.
    :param dot: graphviz graph being built (mutated in place)
    :param category: category string to draw
    :param parent_id: id of the node the new node is attached to
    :param word_idx: index of the word this category belongs to
    :param depth: nesting depth of ``category`` inside the word's category
    :param polarity: polarity propagated down to ``category``; it is inverted
        before an atom is drawn
    :param compt_plus: per-atom-type count of positive atoms drawn so far (mutated)
    :param compt_neg: per-atom-type count of negative atoms drawn so far (mutated)
    :param node_path: positions of the sub-categories followed from the word's
        root category. It is appended to inner-node ids so that identical
        sub-categories at the same depth get distinct nodes, for example both
        halves of ``dl(0,dl(0,np,s),dl(0,np,s))``. It is empty for the root,
        whose id is ``"<category>_<word_idx>_<depth>"``.
    :return: None
    :raises ValueError: if a compound category cannot be parsed, or no link is
        found for a negative atom
    """
    if category in atom_map:
        # atoms are drawn with the inverse of the polarity propagated to them
        polarity = not polarity
        if polarity:
            atoms_idx = compt_plus[category]
            compt_plus[category] += 1
        else:
            idx_neg = compt_neg[category]
            compt_neg[category] += 1
            # a negative atom is named after the position linked to its occurrence index
            linked_positions = np.where(links[atom_map_redux[category]] == idx_neg)[0]
            if linked_positions.size == 0:
                raise ValueError(
                    f"No link found for negative {category!r} atom number {idx_neg}")
            atoms_idx = linked_positions[0]
        atom_id = category + "_" + str(polarity) + "_" + str(atoms_idx)
        dot.node(atom_id, category + " " + ("+" if polarity else "-"))
        dot.edge(parent_id, atom_id)
        return

    category_id = category + "_" + str(word_idx) + "_" + str(depth) + node_path
    dot.node(category_id, category + " " + ("+" if polarity else "-"))
    dot.edge(parent_id, category_id)

    if category.startswith("dr"):
        categories_inside = _split_category(category, 2)
        polarities_inside = [polarity, not polarity]
    elif category.startswith(("dl", "p")):
        categories_inside = _split_category(category, 2)
        polarities_inside = [not polarity, polarity]
    elif category.startswith(("box", "dia")):
        categories_inside = _split_category(category, 1)
        polarities_inside = [polarity]
    else:
        categories_inside = []
        polarities_inside = []

    for child_idx, (sub_category, sub_polarity) in enumerate(
            zip(categories_inside, polarities_inside)):
        recursive_linking(links, dot, sub_category, category_id, word_idx, depth + 1,
                          sub_polarity, compt_plus, compt_neg,
                          node_path + "_" + str(child_idx))


def draw_sentence_output(sentence, categories, links, *, render=True, view=True):
    r"""
    Draw the prediction of a sentence given its categories and links predictions.

    Each word gets a node with its category tree drawn under it (via
    ``recursive_linking``). Afterwards, for every atom type in
    ``atom_map_redux``, the positive atom ``<type>_True_<i>`` is joined to
    ``<type>_False_<i>`` by a red dashed edge.

    :param sentence: list of words
    :param categories: list of categories, one per word
    :param links: links predicted, output of predict_with/without_categories
    :param render: if True (default), call ``dot.render(format="svg", ...)``,
        which writes files to disk
    :param view: if True (default), ask graphviz to open the rendered file;
        only used when ``render`` is True
    :return: dot source
    :raises ValueError: if there are fewer categories than words
    """
    if len(categories) < len(sentence):
        raise ValueError(
            f"Got {len(categories)} categories for a sentence of {len(sentence)} words")

    dot = graphviz.Graph('linking', comment='Axiom linking')
    dot.graph_attr['rankdir'] = 'BT'
    dot.graph_attr['splines'] = 'ortho'
    dot.graph_attr['ordering'] = 'in'

    # One occurrence counter per atom type, shared across the whole sentence.
    # The keys come from atom_map_redux because that is the mapping iterated
    # below when drawing the links.
    compt_plus = dict.fromkeys(atom_map_redux, 0)
    compt_neg = dict.fromkeys(atom_map_redux, 0)

    previous_word_id = None
    for word_idx, word in enumerate(sentence):
        word_id = word + "_" + str(word_idx)
        dot.node(word_id, word)
        if previous_word_id is not None:
            # invisible, non-constraining edge between consecutive words
            # (presumably to keep them in sentence order)
            dot.edge(previous_word_id, word_id, constraint="false", style="invis")
        recursive_linking(links, dot, categories[word_idx], word_id, word_idx, 0,
                          True, compt_plus, compt_neg)
        previous_word_id = word_id

    # axiom links: positive atom i of a type is joined to the negative atom named with the same i
    dot.attr('edge', color='red', style='dashed')
    for atom_type in atom_map_redux:
        for atom_idx in range(compt_plus[atom_type]):
            atom_plus = atom_type + "_True_" + str(atom_idx)
            atom_moins = atom_type + "_False_" + str(atom_idx)
            dot.edge(atom_plus, atom_moins, constraint="false")

    if render:
        dot.render(format="svg", view=view)
    return dot.source


if __name__ == "__main__":
    # Demo run. It is guarded so that importing this module no longer renders
    # an SVG and opens a viewer as a side effect.
    sentence = ["Le", "chat", "est", "noir", "bleu"]
    categories = ["dr(0,s,n)", "dl(0,s,n)", "dr(0,dl(0,n,np),n)", "dl(0,np,n)", "n"]
    links = np.array([[0, 0, 0, 0],
                      [0, 0, 0, 0],
                      [1, 0, 2, 0],
                      [0, 0, 0, 0],
                      [0, 0, 0, 0],
                      [0, 0, 0, 0],
                      [0, 0, 0, 0]])
    draw_sentence_output(sentence, categories, links)