# postprocessing.py (4.64 KiB)
# Authored by Caroline DE POURTALES
import re
from collections import defaultdict

import graphviz
import numpy as np
import regex

from Linker.atom_map import atom_map, atom_map_redux
# Pattern for a complex category of the form `head(<digits>,<arg1>,<arg2>)`,
# e.g. "dr(0,s,n)" or "dr(0,dl(0,n,np),n)". Each argument is either a nested
# category, matched recursively through `(?R)`, or a bare `\w+` atom. `(?R)`
# is why the third-party `regex` module is used: the stdlib `re` has no
# recursion support.
# Capture groups: 1 / 2 = first argument (nested / atom), 3 / 4 = second
# argument (nested / atom); groups that did not participate are None.
# The comma and the second argument are optional, so single-argument
# categories also match (presumably the box/dia case in recursive_linking).
regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
def _split_arguments(category):
    r"""
    Return the direct sub-categories of a complex category.

    :param category: category string such as "dr(0,s,n)"
    :return: the non-None groups captured by ``regex_categories``,
        e.g. ["s", "n"] for "dr(0,s,n)"
    :raises ValueError: if ``category`` does not match ``regex_categories``
    """
    match = regex.match(regex_categories, category)
    if match is None:
        # Fail with a message naming the category instead of an
        # AttributeError on None.
        raise ValueError("Cannot parse category: " + repr(category))
    # NOTE(review): relies on captures made inside the recursive `(?R)` call
    # not leaking into the outer groups — confirm on deeply nested categories.
    return [cat for cat in match.groups() if cat is not None]


def recursive_linking(links, dot, category, parent_id, word_idx, depth,
                      polarity, compt_plus, compt_neg, path=""):
    r"""
    Recursively draw a category as a tree of graphviz nodes, down to its atoms.

    Atoms (keys of ``atom_map``) become leaf nodes whose ids end with the same
    index as the atom they are linked to, so that the positive and negative
    ends of an axiom link can be joined afterwards. Complex categories become
    inner nodes and their arguments are visited recursively.

    :param links: array whose row ``atom_map_redux[atom]`` holds, at position
        ``i``, the index of the negative atom linked to the ``i``-th positive
        atom of that type
    :param dot: graphviz graph the nodes and edges are added to
    :param category: category (sub-)string to draw, e.g. "dr(0,s,n)" or "np"
    :param parent_id: id of the node this category hangs from
    :param word_idx: index of the word in the sentence
    :param depth: nesting depth of ``category`` inside the word's category
    :param polarity: polarity passed down to ``category`` (True is drawn "+")
    :param compt_plus: per-atom-type count of positive atoms drawn so far
        (mutated)
    :param compt_neg: per-atom-type count of negative atoms drawn so far
        (mutated)
    :param path: position taken at each level below the word's root category
        ("" for the root); keeps the ids of identical sibling sub-categories
        distinct so graphviz does not merge them
    :return: None; ``dot`` and the counters are updated in place
    """
    if category in atom_map:
        # Atoms are drawn with the opposite sign of the polarity passed down.
        # NOTE(review): presumably the linker's polarity convention — confirm.
        polarity = not polarity
        if polarity:
            # The i-th positive atom of this type gets index i.
            atoms_idx = compt_plus[category]
            compt_plus[category] += 1
        else:
            # A negative atom reuses the index of the positive atom it is
            # linked to, so both ends of the link share the same id suffix.
            idx_neg = compt_neg[category]
            compt_neg[category] += 1
            atoms_idx = np.where(links[atom_map_redux[category]] == idx_neg)[0][0]
        atom_id = category + "_" + str(polarity) + "_" + str(atoms_idx)
        dot.node(atom_id, category + " " + ("+" if polarity else "-"))
        dot.edge(parent_id, atom_id)
        return

    # Complex category: the root keeps its historical id; nested nodes
    # append their position path to stay unique among identical siblings.
    category_id = category + "_" + str(word_idx) + "_" + str(depth)
    if path:
        category_id += "_" + path
    dot.node(category_id, category + " " + ("+" if polarity else "-"))
    dot.edge(parent_id, category_id)

    if category.startswith("dr"):
        args = _split_arguments(category)
        children = [(args[0], polarity), (args[1], not polarity)]
    elif category.startswith(("dl", "p")):
        args = _split_arguments(category)
        children = [(args[0], not polarity), (args[1], polarity)]
    elif category.startswith(("box", "dia")):
        args = _split_arguments(category)
        children = [(args[0], polarity)]
    else:
        children = []

    for position, (child, child_polarity) in enumerate(children):
        recursive_linking(links, dot, child, category_id, word_idx, depth + 1,
                          child_polarity, compt_plus, compt_neg,
                          path + str(position))
def draw_sentence_output(sentence, categories, links, *, render=True, view=True):
    r"""
    Draw the prediction of a sentence given its categories and link predictions.

    Each word becomes a node. Its category is expanded below it as a tree by
    ``recursive_linking``. Every positive atom is then joined to the negative
    atom it is linked to by a red dashed edge.

    :param sentence: list of words
    :param categories: list of categories, indexed by word position
    :param links: links predicted, output of predict_with/without_categories;
        rows are indexed by ``atom_map_redux[atom]``
    :param render: if True (default), render the graph to SVG via
        ``dot.render``
    :param view: if True (default), open the rendered file in the system
        viewer; only used when ``render`` is True
    :return: dot source
    """
    dot = graphviz.Graph('linking', comment='Axiom linking')
    dot.graph_attr['rankdir'] = 'BT'
    dot.graph_attr['splines'] = 'ortho'
    dot.graph_attr['ordering'] = 'in'

    # Per-atom-type counters filled by recursive_linking. A defaultdict stays
    # in sync with the atom inventory instead of hard-coding its keys; missing
    # types simply count 0 in the axiom-link loop below.
    compt_plus = defaultdict(int)
    compt_neg = defaultdict(int)

    last_word_id = ""
    for word_idx, word in enumerate(sentence):
        word_id = word + "_" + str(word_idx)
        dot.node(word_id, word)
        if word_idx > 0:
            # Invisible, non-ranking edge between consecutive words;
            # presumably helps keep the words in sentence order.
            dot.edge(last_word_id, word_id, constraint="false", style="invis")
        recursive_linking(links, dot, categories[word_idx], word_id, word_idx, 0,
                          True, compt_plus, compt_neg)
        last_word_id = word_id

    # Axiom links: all edges added from here on are red and dashed.
    dot.attr('edge', color='red')
    dot.attr('edge', style='dashed')
    for atom_type in atom_map_redux:
        for atom_idx in range(compt_plus[atom_type]):
            atom_plus = atom_type + "_" + str(True) + "_" + str(atom_idx)
            atom_moins = atom_type + "_" + str(False) + "_" + str(atom_idx)
            dot.edge(atom_plus, atom_moins, constraint="false")

    if render:
        dot.render(format="svg", view=view)
    return dot.source
if __name__ == "__main__":
    # Hand-written demo: five words, one category per word, and a links array
    # whose rows are indexed by atom_map_redux (presumably one row per atom
    # type). Guarded so that importing this module does not render files or
    # open a viewer.
    sentence = ["Le", "chat", "est", "noir", "bleu"]
    categories = ["dr(0,s,n)", "dl(0,s,n)", "dr(0,dl(0,n,np),n)", "dl(0,np,n)", "n"]
    links = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 2, 0], [0, 0, 0, 0],
                      [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
    draw_sentence_output(sentence, categories, links)