diff --git a/callbacks.py b/callbacks.py
index 5ef62d601526f58dbf4d39f45ca698ee867d80ad..25a794b3eb7b330dda86eeb322dc052bd9eb29b9 100644
--- a/callbacks.py
+++ b/callbacks.py
@@ -46,10 +46,11 @@ def register_callbacks(page_home, page_course, page_application, app):
         Input('explanation_type', 'value'),
         Input('solver_sat', 'value'),
         Input('expl_choice', 'value'),
+        Input('cont_expl_choice', 'value'),
         prevent_initial_call=True
     )
     def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, model_info, model_info_filename, \
-                       instance_contents, instance_filename, enum, xtype, solver, expl_choice):
+                       instance_contents, instance_filename, enum, xtype, solver, expl_choice, cont_expl_choice):
         ctx = dash.callback_context
         if ctx.triggered:
             ihm_id = ctx.triggered[0]['prop_id'].split('.')[0]
@@ -116,27 +117,36 @@ def register_callbacks(page_home, page_course, page_application, app):
             model_application.update_expl(expl_choice)
             return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation
+        # Choice of CxP to draw
+        elif ihm_id == 'cont_expl_choice' :
+            if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.enum<=0 or len(model_application.xtype)==0:
+                raise PreventUpdate
+            model_application.update_cont_expl(cont_expl_choice)
+            return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation
 
     @app.callback(
         Output('explanation', 'hidden'),
-        Output('navigate_label', 'hidden'),
-        Output('navigate_dropdown', 'hidden'),
+        Output('interaction_graph', 'hidden'),
         Output('expl_choice', 'options'),
+        Output('cont_expl_choice', 'options'),
         Input('explanation', 'children'),
         Input('explanation_type', 'value'),
         prevent_initial_call=True
     )
     def layout_buttons_navigate_expls(explanation, explanation_type):
         if explanation is None or len(explanation_type)==0:
-            return True, True, True, {}
+            return True, True, {}, {}
         elif "AXp" not in explanation_type and "CXp" in explanation_type:
-            return False, True, True, {}
+            return False, True, {}, {}
         else :
-            options = {}
+            options_expls = {}
+            options_cont_expls = {}
             model_application = page_application.model
             for i in range (len(model_application.list_expls)):
-                options[str(model_application.list_expls[i])] = model_application.list_expls[i]
-            return False, False, False, options
+                options_expls[str(model_application.list_expls[i])] = model_application.list_expls[i]
+            for i in range (len(model_application.list_cont_expls)):
+                options_cont_expls[str(model_application.list_cont_expls[i])] = model_application.list_cont_expls[i]
+            return False, False, options_expls, options_cont_expls
 
     @app.callback(
         Output('choice_info_div', 'hidden'),
diff --git a/callbacks_detached.py b/callbacks_detached.py
deleted file mode 100644
index a8db16ff713479399892471d56224bd0cd9456b2..0000000000000000000000000000000000000000
--- a/callbacks_detached.py
+++ /dev/null
@@ -1,172 +0,0 @@
-import dash
-import pandas as pd
-from dash import Input, Output, State
-from dash.dependencies import Input, Output, State
-from dash.exceptions import PreventUpdate
-
-from utils import parse_contents_graph, parse_contents_instance, parse_contents_data
-
-
-def register_callbacks(page_home, page_course, page_application, app):
-    page_list = ['home', 'course', 'application']
-
-    @app.callback(
-        Output('page-content', 'children'),
-
Input('url', 'pathname')) - def display_page(pathname): - if pathname == '/': - return page_home - if pathname == '/application': - return page_application.view.layout - if pathname == '/course': - return page_course - - @app.callback(Output('home-link', 'active'), - Output('course-link', 'active'), - Output('application-link', 'active'), - Input('url', 'pathname')) - def navbar_state(pathname): - active_link = ([pathname == f'/{i}' for i in page_list]) - return active_link[0], active_link[1], active_link[2] - - @app.callback( - Output('graph', 'children'), - Input('ml_model_choice', 'value'), - prevent_initial_call=True - ) - def update_ml_type(value_ml_model): - model_application = page_application.model - model_application.update_ml_model(value_ml_model) - return None - - @app.callback( - Output('pretrained_model_filename', 'children'), - Output('graph', 'children'), - Input('ml_pretrained_model_choice', 'contents'), - State('ml_pretrained_model_choice', 'filename'), - prevent_initial_call=True - ) - def update_ml_pretrained_model(pretrained_model_contents, pretrained_model_filename): - model_application = page_application.model - if model_application.ml_model is None : - raise PreventUpdate - graph = parse_contents_graph(pretrained_model_contents, pretrained_model_filename) - model_application.update_pretrained_model(graph) - if not model_application.add_info : - model_application.update_pretrained_model_layout() - return pretrained_model_filename, model_application.component.network - else : - return pretrained_model_filename, None - - @app.callback( - Output('info_filename', 'children'), - Output('graph', 'children'), - Input('model_info_choice', 'contents'), - State('model_info_choice', 'filename'), - prevent_initial_call=True - ) - def update_info_model(model_info, model_info_filename): - model_application = page_application.model - if model_application.ml_model is None : - raise PreventUpdate - model_info = parse_contents_data(model_info, model_info_filename) - model_application.update_pretrained_model_layout_with_info(model_info, model_info_filename) - return model_info_filename, model_application.component.network - - @app.callback( - Output('instance_filename', 'children'), - Output('graph', 'children'), - Output('explanation', 'children'), - Input('ml_instance_choice', 'contents'), - State('ml_instance_choice', 'filename'), - prevent_initial_call=True - ) - def update_instance(instance_contents, instance_filename): - model_application = page_application.model - if model_application.ml_model is None or model_application.pretrained_model is None or model_application.enum<=0 or model_application.xtype is None : - raise PreventUpdate - instance = parse_contents_instance(instance_contents, instance_filename) - model_application.update_instance(instance) - return instance_filename, model_application.component.network, model_application.component.explanation - - @app.callback( - Output('explanation', 'children'), - Input('number_explanations', 'value'), - prevent_initial_call=True - ) - def update_enum(enum): - model_application = page_application.model - if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.xtype is None: - raise PreventUpdate - model_application.update_enum(enum) - return model_application.component.explanation - - @app.callback( - Output('explanation', 'children'), - Input('explanation_type', 'value'), - prevent_initial_call=True - ) - def update_xtype(xtype): - 
model_application = page_application.model - if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.enum<=0 : - raise PreventUpdate - model_application.update_xtype(xtype) - return model_application.component.explanation - - @app.callback( - Output('explanation', 'children'), - Input('solver_sat', 'value'), - prevent_initial_call=True -) - def update_solver(solver): - model_application = page_application.model - if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.enum<=0 or len(model_application.xtype)==0: - raise PreventUpdate - model_application.update_solver(solver) - return model_application.component.explanation - - @app.callback( - Output('graph', 'children'), - Input('expl_choice', 'value'), - prevent_initial_call=True - ) - def update_expl_choice( expl_choice): - model_application = page_application.model - if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.enum<=0 or len(model_application.xtype)==0: - raise PreventUpdate - model_application.update_expl(expl_choice) - return model_application.component.network - - @app.callback( - Output('explanation', 'hidden'), - Output('navigate_label', 'hidden'), - Output('navigate_dropdown', 'hidden'), - Output('expl_choice', 'options'), - Input('explanation', 'children'), - Input('explanation_type', 'value'), - prevent_initial_call=True - ) - def layout_buttons_navigate_expls(explanation, explanation_type): - if explanation is None or len(explanation_type)==0: - return True, True, True, {} - elif "AXp" not in explanation_type and "CXp" in explanation_type: - return False, True, True, {} - else : - options = {} - model_application = page_application.model - for i in range (len(model_application.list_expls)): - options[str(model_application.list_expls[i])] = model_application.list_expls[i] - return False, False, False, options - - @app.callback( - Output('choice_info_div', 'hidden'), - Input('add_info_model_choice', 'on'), - prevent_initial_call=True - ) - def add_model_info(add_info_model_choice): - model_application = page_application.model - model_application.update_info_needed(add_info_model_choice) - if add_info_model_choice: - return False - else : - return True diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py index 8a56c32349fced19e8c03889bb26fdc176b79973..f9d34e32c731e7216577b94948964e72b6b54ec5 100644 --- a/pages/application/DecisionTree/DecisionTreeComponent.py +++ b/pages/application/DecisionTree/DecisionTreeComponent.py @@ -11,7 +11,8 @@ from pages.application.DecisionTree.utils.dtree import DecisionTree from pages.application.DecisionTree.utils.dtviz import (visualize, visualize_expl, - visualize_instance) + visualize_instance, + visualize_contrastive_expl) class DecisionTreeComponent(): @@ -135,7 +136,7 @@ class DecisionTreeComponent(): self.explanation.append(html.H4("Instance : \n")) self.explanation.append(html.P(str([str(instance[i]) for i in range (len(instance))]))) for k in explanation.keys() : - if k != "List of path explanation(s)": + if k != "List of path explanation(s)" and k!= "List of path contrastive explanation(s)" : if k in ["List of abductive explanation(s)","List of contrastive explanation(s)"] : self.explanation.append(html.H4(k)) for expl in explanation[k] : @@ -146,8 
+147,9 @@ class DecisionTreeComponent(): self.explanation.append(html.P(k + explanation[k])) else : list_explanations_path = explanation["List of path explanation(s)"] + list_contrastive_explanations_path = explanation["List of path contrastive explanation(s)"] - return list_explanations_path + return list_explanations_path, list_contrastive_explanations_path def draw_explanation(self, instance, expl) : instance = self.translate_instance(instance) @@ -157,3 +159,12 @@ class DecisionTreeComponent(): style = {"width": "50%", "height": "80%", "background-color": "transparent"})]) + + def draw_contrastive_explanation(self, instance, cont_expl) : + instance = self.translate_instance(instance) + dot_source = visualize_contrastive_expl(self.dt, instance, cont_expl) + self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( + dot_source=dot_source, + style = {"width": "50%", + "height": "80%", + "background-color": "transparent"})]) diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py index 79088c2cce1cbb0c4c6e65183195a4a110f357f6..c2775dfac8d7adbaca3ff30e8033b92ea2570c26 100644 --- a/pages/application/DecisionTree/utils/dtree.py +++ b/pages/application/DecisionTree/utils/dtree.py @@ -393,14 +393,18 @@ class DecisionTree(): done.append(target) return done + list_contrastive_expls = [] + to_hit = [set(s) for s in to_hit] to_hit.sort(key=lambda s: len(s)) expls = list(reduce(process_set, to_hit, [])) list_expls_str = [] explanation = {} for expl in expls: + list_contrastive_expls.append([self.fvmap[(p[0],1-p[1])] for p in sorted(expl, key=lambda p: p[0])]) list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term)) + explanation["List of path contrastive explanation(s)"] = list_contrastive_expls explanation["List of contrastive explanation(s)"] = list_expls_str explanation["Number of contrastive explanation(s) : "]=str(len(expls)) explanation["Minimal contrastive explanation : "]= str( min([len(e) for e in expls])) diff --git a/pages/application/DecisionTree/utils/dtviz.py b/pages/application/DecisionTree/utils/dtviz.py index 21ea63920338159b71e45c44a1cd2b772084bc9f..cf71d2bda925a3cfedcba5ddb641b80069bf5e92 100755 --- a/pages/application/DecisionTree/utils/dtviz.py +++ b/pages/application/DecisionTree/utils/dtviz.py @@ -12,7 +12,8 @@ #============================================================================== import getopt import pygraphviz - +import ast +import re # #============================================================================== def create_legend(g): @@ -22,6 +23,8 @@ def create_legend(g): legend.add_node("b", style = "invis") legend.add_node("c", style = "invis") legend.add_node("d", style = "invis") + legend.add_node("e", style = "invis") + legend.add_node("f", style = "invis") legend.add_edge("a","b") edge = legend.get_edge("a","b") @@ -34,7 +37,10 @@ def create_legend(g): edge.attr["color"] = "blue" edge.attr["style"] = "dashed" - + legend.add_edge("e","f") + edge = legend.get_edge("e","f") + edge.attr["label"] = "contrastive explanation" + edge.attr["color"] = "red" # #============================================================================== def visualize(dt): @@ -194,3 +200,64 @@ def visualize_expl(dt, instance, expl): # saving file g.layout(prog='dot') return(g.to_string()) + +#============================================================================== +def visualize_contrastive_expl(dt, 
instance, cont_expl): + """ + Visualize a DT with graphviz and plot the running instance. + """ + #path that follows the instance - colored in blue + path, term, depth = dt.execute(instance) + edges_instance = [] + for i in range (len(path)-1) : + edges_instance.append((path[i], path[i+1])) + edges_instance.append((path[-1],"term:"+term)) + + g = pygraphviz.AGraph(directed=True, strict=True) + g.edge_attr['dir'] = 'forward' + + g.graph_attr['rankdir'] = 'TB' + + # non-terminal nodes + for n in dt.nodes: + g.add_node(n, label=dt.feature_names[dt.nodes[n].feat]) + node = g.get_node(n) + node.attr['shape'] = 'circle' + node.attr['fontsize'] = 13 + + # terminal nodes + for n in dt.terms: + g.add_node(n, label=dt.terms[n]) + node = g.get_node(n) + node.attr['shape'] = 'square' + node.attr['fontsize'] = 13 + + # transitions + for n1 in dt.nodes: + for v in dt.nodes[n1].vals: + n2 = dt.nodes[n1].vals[v] + n2_type = g.get_node(n2).attr['shape'] + g.add_edge(n1, n2) + edge = g.get_edge(n1, n2) + if len(v) == 1: + edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] + else: + edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) + + #instance path in dashed + if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): + edge.attr['style'] = 'dashed' + + for label in edge.attr['label'].split('\n'): + if label in cont_expl: + edge.attr['color'] = 'red' + + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + + g.add_subgraph(name='legend') + create_legend(g) + + # saving file + g.layout(prog='dot') + return(g.to_string()) diff --git a/pages/application/NaiveBayes/NaiveBayesComponent.py b/pages/application/NaiveBayes/NaiveBayesComponent.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..98c7848a12a2e33bc270c8d30a0caab823f10c05 100644 --- a/pages/application/NaiveBayes/NaiveBayesComponent.py +++ b/pages/application/NaiveBayes/NaiveBayesComponent.py @@ -0,0 +1,57 @@ +from os import path +import base64 + +import dash_bootstrap_components as dbc +import numpy as np +from dash import dcc, html +import subprocess +import shlex + + + +class NaiveBayesComponent(): + + def __init__(self, model, type_model='SKL', info=None, type_info=''): + + #Conversion model + p=subprocess.Popen(['perl','pages/application/NaiveBayes/utils/cnbc2xlc.pl', model],stdout=subprocess.PIPE) + print(p.stdout.read()) + + self.naive_bayes = model + self.map_file = "" + + self.network = html.Div([]) + self.explanation = [] + + + def update_with_explicability(self, instance, enum, xtype, solver) : + + # Call explanation + p=subprocess.Popen(['perl','pages/application/NaiveBayes/utils/xpxlc.pl', self.naive_bayes, instance, self.map_file],stdout=subprocess.PIPE) + print(p.stdout.read()) + + self.explanation = [] + list_explanations_path=[] + explanation = {} + + self.network = html.Div([]) + + #Creating a clean and nice text component + #instance plotting + self.explanation.append(html.H4("Instance : \n")) + self.explanation.append(html.P(str([str(instance[i]) for i in range (len(instance))]))) + for k in explanation.keys() : + if k != "List of path explanation(s)" and k!= "List of path contrastive explanation(s)" : + if k in ["List of abductive explanation(s)","List of contrastive explanation(s)"] : + self.explanation.append(html.H4(k)) + for expl in explanation[k] : + self.explanation.append(html.Hr()) + self.explanation.append(html.P(expl)) + self.explanation.append(html.Hr()) + else : + 
self.explanation.append(html.P(k + explanation[k])) + else : + list_explanations_path = explanation["List of path explanation(s)"] + list_contrastive_explanations_path = explanation["List of path contrastive explanation(s)"] + + return list_explanations_path, list_contrastive_explanations_path diff --git a/pages/application/NaiveBayes/utils/Parsers.pm b/pages/application/NaiveBayes/utils/Parsers.pm new file mode 100644 index 0000000000000000000000000000000000000000..2fd493bc549790d254761c02222414a6a082eee2 --- /dev/null +++ b/pages/application/NaiveBayes/utils/Parsers.pm @@ -0,0 +1,319 @@ +package Parsers; + +use strict; +use warnings; + +use Data::Dumper; + +use POSIX qw( !assert ); +use Exporter; + +require Utils; # Must use require, to get INC updated +import Utils qw( &get_progname &get_progpath ); + +BEGIN { + @Parsers::ISA = ('Exporter'); + @Parsers::EXPORT_OK = + qw( &parse_xlc &parse_cnbc &parse_xmap + &parse_instance &parse_explanations + &parse_blc &parse_acc ); +} + +use constant F_ERR_MSG => + "Please check file name, existence, permissions, etc.\n"; +use constant HLPMAP => 1; +use constant CCAT_CH => '_'; +use constant CCHK => 0; + +if (CCHK) { + ## Uncomment to use assertions && debug messages + #use Carp::Assert; # Assertions are on. +} + + +# Parse XLC format +sub parse_xlc() +{ + my ($opts, $xlc, $fname) = @_; + + open(my $fh, "<$fname") || die "Unable to open file $fname. " . F_ERR_MSG; + my ($nc, $nr, $rmode) = (0, 0, 0); + while(<$fh>) { + chomp; + next if m/^\s*c\s+$/; + if ($rmode == 0) { # Read number of features + m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n"; + ($xlc->{NV}, $rmode) = ($1, 1); + } + elsif ($rmode == 1) { # Read w0 + m/^\s*(\-?\d+\.?\d*)\s*$/ || die "Unable to match: $_\n"; + ($xlc->{W0}, $rmode) = ($1, 2); + } + elsif ($rmode == 2) { # Read number of real-valued features + m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n"; + ($xlc->{NReal}, $rmode) = ($1, 3); + if ($xlc->{NReal} == 0) { $rmode = 4; } + } + elsif ($rmode == 3) { # Read real-valued coefficients + m/^\s*(\-?\d+\.?\d*)\s*$/ || die "Unable to match: $_\n"; + push @{$xlc->{RVs}}, $1; + if (++$nr == $xlc->{NReal}) { ($nr, $rmode) = (0, 4); } + } + elsif ($rmode == 4) { # Read number of categorical features + m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n"; + ($xlc->{NCat}, $rmode) = ($1, 5); + } + elsif ($rmode == 5) { # Read domains and weights of cat. features + my $cvi = "CVs$nc"; + @{$xlc->{$cvi}} = split(/ +/); + push @{$xlc->{CDs}}, shift @{$xlc->{$cvi}}; + if (++$nc == $xlc->{NCat}) { $rmode = 6; } + } + else { die "Invalid state with input: $_\n"; } + } + close($fh); +} + + +# Parse map file +sub parse_xmap() +{ + my ($opts, $xmap, $fname) = @_; + + open(my $fh, "<$fname") || die "Unable to open file $fname. " . 
F_ERR_MSG; + my ($cc, $nv, $nc, $nr, $rmode) = (0, 0, 0, 0, 0); + while(<$fh>) { + chomp; + next if m/^\s*c\s+$/; + if ($rmode == 0) { # Read number of classes + m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n"; + ($xmap->{NC}, $rmode, $cc) = ($1, 1, 0); + if ($xmap->{NC} == 0) { $rmode = 2; } + } + elsif ($rmode == 1) { # Read class name maps + my @toks = split(/ +/); + my $cid = shift @toks; + ${$xmap->{ClMap}}[$cid] = join(CCAT_CH, @toks); + if (++$cc == $xmap->{NC}) { $rmode = 2; } + } + elsif ($rmode == 2) { # Read number of features + m/^\s*(\d+)\s*$/ || die "Unable to match \@ $rmode: $_\n"; + ($xmap->{NV}, $rmode) = ($1, 3); + } + elsif ($rmode == 3) { # Read number of real-valued features + m/^\s*(\d+)\s*$/ || die "Unable to match \@ $rmode: $_\n"; + ($xmap->{NReal}, $rmode, $nr) = ($1, 4, 0); + if ($xmap->{NReal} == 0) { $rmode = 5; } + } + elsif ($rmode == 4) { # Read map of real-value features + my @toks = split(/ +/); + my $rid = shift @toks; + ${$xmap->{VMap}}[$rid] = join(CCAT_CH, @toks); + if (++$nr == $xmap->{NReal}) { $rmode = 5; } + } + elsif ($rmode == 5) { # Read number of categorical features + m/^\s*(\d+)\s*$/ || die "Unable to match \@ $rmode: $_\n"; + ($xmap->{NCat}, $rmode, $nc) = ($1, 6, $nr); + } + elsif ($rmode == 6) { # Read categorical feature + my @toks = split(/ +/); + my $cid = shift @toks; + if (!HLPMAP) { + ${$xmap->{VMap}}[$cid] = join(CCAT_CH, @toks); } + else { + my ($sch, $ech, $jch) = ('', '', ''); + if ($#toks > 0) { ($sch, $ech, $jch) = ('\'', '\'', ' '); } + ${$xmap->{VMap}}[$cid] = $sch . join($jch, @toks) . $ech; + } + $rmode = 7; + if (CCHK) { assert($cid == $nc, "Invalid categorical ID"); } + } + elsif ($rmode == 7) { # Read domain size of current feature + m/^\s*(\d+)\s*$/ || die "Unable to match \@ $rmode: $_\n"; + ($xmap->{CDs}->{$nc}, $rmode, $nv) = ($1, 8, 0); + } + elsif ($rmode == 8) { # Read values of categorical feature + my @toks = split(/ +/); + my $vid = shift @toks; + if (!HLPMAP) { + ${$xmap->{CMap}->{$nc}}[$vid] = join(CCAT_CH, @toks); } + else { + my ($repl, $sch, $ech, $jch) = (0, '', '', ''); + for (my $i=0; $i<=$#toks; ++$i) { + if ($toks[$i] =~ m/${$xmap->{VMap}}[$nc]/) { + $toks[$i] =~ s/${$xmap->{VMap}}[$nc]/\?\?/g; + $repl = 1; + } + } + if ($#toks > 0 && !$repl) { ($sch,$ech,$jch)=('\'','\'',' '); } + ${$xmap->{CMap}->{$nc}}[$vid] = $sch . join($jch, @toks) . $ech; + } + if (++$nv == $xmap->{CDs}->{$nc}) { + if (++$nc == $xmap->{NReal}+$xmap->{NCat}) { $rmode = 9; } + else { $rmode = 6; } + } + } + else { die "Invalid state with input \@ $rmode: $_\n"; } + } + close($fh); +} + + +# Parse CNBC format -- currently hard-coded for 2 classes +sub parse_cnbc() +{ + my ($opts, $cnbc, $fname) = @_; + + open(my $fh, "<$fname") || die "Unable to open file $fname. " . 
F_ERR_MSG; + my ($cc, $cv, $pol, $rmode) = (0, 0, 0, 0); + while(<$fh>) { + chomp; + if ($rmode == 0) { # Read number of classes + m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n"; + ($cnbc->{NC}, $rmode, $cc) = ($1, 1, 0); + } + elsif ($rmode == 1) { # Read priors + m/^\s*(\-?\d+\.?\d*)\s*$/ || die "Unable to match: $_\n"; + push @{$cnbc->{Prior}}, $1; + if (++$cc == $cnbc->{NC}) { $rmode = 2; } + } + elsif ($rmode == 2) { # Read number of features + m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n"; + ($cnbc->{NV}, $cv, $rmode) = ($1, 0, 3); + } + elsif ($rmode == 3) { # Read domain size of feature + my $cpt = "CPT$cv"; + if ($cv == $cnbc->{NV}) { die "Too many features specified?\n"; } + m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n"; + ($cnbc->{$cpt}->{D}, $cc, $rmode) = ($1, 0, 4); + } + elsif ($rmode == 4) { # Read CPT for feature + my $cpt = "CPT$cv"; + my $ccl = "C$cc"; + my @probs = split(/ +/); + if ($#probs+1 != $cnbc->{$cpt}->{D}) { die "Invalid CPT def\n"; } + for (my $i=0; $i<=$#probs; ++$i) { + $probs[$i] =~ m/(\-?\d+\.?\d*)/ || die "Unable to match: $_\n"; + push @{$cnbc->{$cpt}->{$ccl}}, $probs[$i]; + } + if (++$cc == $cnbc->{NC}) { + ($cv, $cc, $rmode) = ($cv+1, 0, 3); # Move to next feature + } + } else { die "Unexpected read mode in CNBC file\n"; } + } + close($fh); +} + + +# Parse BLC format +sub parse_blc() +{ + my ($opts, $blc, $fname) = @_; + open(my $fh, "<$fname") || die "Unable to open file $fname. " . F_ERR_MSG; + my ($rmode, $cnt) = (0, 0); + while(<$fh>) { + next if m/^\s*$/ || m/^c\s+/; + chomp; + if ($rmode == 0) { + m/\s*(\d+)\s*$/ || die "Unable to match: $_\n"; + ($blc->{NV}, $rmode) = ($1, 1); + } + elsif ($rmode == 1) { + if ($cnt == $blc->{NV}+1) { + die "Too many lines in BLC description??\n"; } + m/^\s*(\-?\d+\.?\d*)\s*$/ || die "Unable to match: $_\n"; + ${$blc->{Ws}}[$cnt++] = $1; + } + } + close($fh); +} + +# Parse ACC format +sub parse_acc() +{ + my ($opts, $acc, $fname) = @_; + + open(my $fh, "<$fname") || die "Unable to open file $fname. " . F_ERR_MSG; + my ($cc, $cv, $pol, $rmode) = (0, 0, 0, 0); + while(<$fh>) { + next if m/^\s*$/ || m/^c\s+/; + chomp; + if ($rmode == 0) { + m/\s*(\d)\s*$/ || die "Unable to match: $_\n"; + ($acc->{NC}, $rmode) = ($1, 1); + } + elsif ($rmode == 1) { + m/\s*(\d+)\s*$/ || die "Unable to match: $_\n"; + ($acc->{NV}, $rmode) = ($1, 2); + } + elsif ($rmode == 2) { + my $class = "C$cc"; + m/^\s*(\-?\d+\.?\d*)\s*$/ || die "Unable to match: $_\n"; + $acc->{VV}->{$class}->{W0} = $1; + $rmode = 3; + } + elsif ($rmode == 3) { + my $class = "C$cc"; + my $polarity = "P$pol"; + m/^\s*(\-?\d+\.?\d*)\s*$/ || die "Unable to match: $_\n"; + ${$acc->{VV}->{$class}->{$polarity}}[$cv] = $1; + $pol = 1 - $pol; + if ($pol == 0) { $cv++; } + if ($cv == $acc->{NV}) { + ($cc, $cv, $pol) = ($cc+1, 0, 0); + if ($cc == $acc->{NC}) { last; } + $rmode = 2; + } + } + } + close($fh); +} + + +# Parse instance format +sub parse_instance() +{ + my ($opts, $inst, $fname) = @_; + + open(my $fh, "<$fname") || die "Unable to open file $fname. " . 
F_ERR_MSG; + my ($cnt, $rmode) = (0, 0); + while(<$fh>) { + next if m/^\s*$/ || m/^c\s+/; + chomp; + if ($rmode == 0) { + m/\s*(\d+)\s*$/ || die "Unable to match: $_\n"; + ($inst->{NV}, $rmode) = ($1, 1); + } + elsif ($rmode == 1) { + m/\s*(\d+)\s*$/ || die "Unable to match: $_\n"; + ${$inst->{E}}[$cnt++] = $1; + if ($cnt == $inst->{NV}) { $rmode = 2; } + } + elsif ($rmode == 2) { + m/\s*(\d+)\s*$/ || die "Unable to match: $_\n"; + $inst->{C} = $1; + } + } + close($fh); +} + +# Parse explanations +sub parse_explanations() +{ + my ($fname, $xpl) = @_; + open(my $fh, "<$fname") || die "Unable to open file $fname. " . F_ERR_MSG; + while(<$fh>) { + next if m/^\s*$/ || m/^c\s+/; + chomp; + my @lits = split(/ +/); + shift @lits; # Drop 'Expl: ' + push @{$xpl->{Expl}}, \@lits; + } + close($fh); +} + + +END { +} + +1; # to ensure that the 'require' or 'use' succeeds diff --git a/pages/application/NaiveBayes/utils/Utils.pm b/pages/application/NaiveBayes/utils/Utils.pm new file mode 100644 index 0000000000000000000000000000000000000000..d694e44dec936f672410a803ea8bda1eaf08a93d --- /dev/null +++ b/pages/application/NaiveBayes/utils/Utils.pm @@ -0,0 +1,114 @@ +package Utils; + +use strict; +use warnings; + +use Data::Dumper; + +use POSIX; +use Exporter(); +use Sys::Hostname; + +BEGIN { + @Utils::ISA = ('Exporter'); + @Utils::EXPORT_OK = qw( &get_progname &get_progpath &round &SIG_handler ); +} + + +#------------------------------------------------------------------------------# +# Execution path handling +#------------------------------------------------------------------------------# + +sub get_progname() { + my @progname_toks = split(/\//, $0); + my $progname = $progname_toks[$#progname_toks]; + #print "$progname\n"; + return $progname; +} + +sub get_progpath() { + my @progname_toks = split(/\//, $0); + pop @progname_toks; + my $progpath = join('/', @progname_toks); + if ($progpath eq '') { $progpath = '\.\/'; } + #print "Prog Path: $progpath\n"; #exit; + return $progpath; +} + +sub get_hostname() { + my $full_host_name = &Sys::Hostname::hostname(); + $full_host_name =~ m/(\w+)\.?/; + my $rhname = $1; + #print "|$hostname|\n"; exit; + return $rhname; +} + +sub resolve_inc() { # Kept here as a template; need a copy in each script... + my ($cref, $pmname) = @_; + my @progname_toks = split(/\//, $0); + pop @progname_toks; + my $progpath = join('/', @progname_toks); + my $fullname = $progpath . '/' . $pmname; + my $fh; + open($fh, "<$fullname") || die "non-existing file: $pmname\n"; + return $fh; +} + + +#------------------------------------------------------------------------------# +# Signal handling utilities +#------------------------------------------------------------------------------# + +sub register_handlers() +{ + $SIG{'INT'} = 'Utils::INT_handler'; + $SIG{'TERM'} = 'Utils::INT_handler'; + $SIG{'ABRT'} = 'Utils::SIG_handler'; + $SIG{'SEGV'} = 'Utils::SIG_handler'; + $SIG{'BUS'} = 'Utils::SIG_handler'; + $SIG{'QUIT'} = 'Utils::SIG_handler'; + $SIG{'XCPU'} = 'Utils::SIG_handler'; +} + +my @args = (); +my @callback = (); + +sub push_arg() +{ + push @args, shift; +} + +sub push_callback() +{ + push @callback, shift; +} + +sub SIG_handler() +{ + &Utils::INT_handler(); +} + +sub INT_handler() +{ + # call any declared callbacks, e.g. to prints stats, summaries, etc. + print "\nReceived system signal. 
Cleaning up & terminating...\n"; + foreach my $cback (@callback) { + &{$cback}(\@args); + } + exit 20; # 20 denotes resources exceeded condition (see below) +} + + +#------------------------------------------------------------------------------# +# Useful utils +#------------------------------------------------------------------------------# + +sub round() { + my ($rval) = @_; + return int($rval + 0.5); +} + +END { +} + +1; # to ensure that the 'require' or 'use' succeeds diff --git a/pages/application/NaiveBayes/utils/Writers.pm b/pages/application/NaiveBayes/utils/Writers.pm new file mode 100644 index 0000000000000000000000000000000000000000..e281cec9d853630b4b6e7c8c5486f9d37974c693 --- /dev/null +++ b/pages/application/NaiveBayes/utils/Writers.pm @@ -0,0 +1,42 @@ +package Writers; + +use strict; +use warnings; + +use Data::Dumper; + +use POSIX; +use Exporter; + +require Utils; # Must use require, to get INC updated +import Utils qw( &get_progname &get_progpath ); + +BEGIN { + @Writers::ISA = ('Exporter'); + @Writers::EXPORT_OK = qw( &write_xlc ); +} + + +# Export XLC format +sub write_xlc() +{ + my ($opts, $xlc) = @_; + print("$xlc->{NV}\n"); + print("$xlc->{W0}\n"); + print("$xlc->{NReal}\n"); + for (my $i=0; $i<$xlc->{NReal}; ++$i) { + print("${$xlc->{RVs}}[$i]\n"); + } + print("$xlc->{NCat}\n"); + for (my $i=0; $i<$xlc->{NCat}; ++$i) { + my $cvi = "CVs$i"; + print("${$xlc->{CDs}}[$i] "); + print("@{$xlc->{$cvi}}\n"); + } +} + + +END { +} + +1; # to ensure that the 'require' or 'use' succeeds diff --git a/pages/application/NaiveBayes/utils/cnbc2xlc.pl b/pages/application/NaiveBayes/utils/cnbc2xlc.pl new file mode 100755 index 0000000000000000000000000000000000000000..b0383d086c28cc4dc7ad51c7361f71540963eb4a --- /dev/null +++ b/pages/application/NaiveBayes/utils/cnbc2xlc.pl @@ -0,0 +1,247 @@ +#!/usr/bin/env perl + +## Tool for translating the probabilities of an CNBC into a +## sequence of non-negative weights which are then represented +## in the XLC format. +## Script specifically assumes *2* classes + +push @INC, \&resolve_inc; + +use strict; +use warnings; +use Data::Dumper; +use Getopt::Std; + +require Parsers; +import Parsers qw( parse_cnbc ); + +require Writers; +import Writers qw( write_xlc ); + +use constant DBG => 0; ## Also, comment out unused 'uses' +use constant CHK => 0; + +my $f_err_msg = "Please check file name, existence, permissions, etc.\n"; + +# 0. Read command line arguments +my %opts = (); +&read_opts(\%opts); + + +if ((CHK || DBG) && (defined($opts{k}) || defined($opts{d}))) { + ## Uncomment to use assertions && debug messages + #use Carp::Assert; # Assertions are on. + #if (DBG && $opts{d}) { + # use Data::Dumper; + #} +} +if (defined($opts{o})) { + open ($opts{FH}, '>', $opts{o}); + select($opts{FH}); +} + + +# 1. Data structures +my %cnbc = (); +my %xlc = (); +my $mval = 0; +my $tval = 0; + +# 2. Read ML model (definition of (C)NBC in CNBC format) +&parse_cnbc(\%opts, \%cnbc, $opts{f}); +if ($opts{d}) { warn Data::Dumper->Dump([ \%cnbc ], [ qw(cnbc) ]); } + +# 3. Translate CNBC weights (i.e. probs) into CNBC weights (i.e. additive & >=0) +&process_weights(\%opts, \%cnbc); +if ($opts{d}) { warn Data::Dumper->Dump([ \%cnbc ], [ qw(cnbc) ]); } + +#4. Reduce CNBC (w/ weights) into XLC +&reduce_cnbc_xlc(\%opts, \%cnbc, \%xlc); +if ($opts{d}) { warn Data::Dumper->Dump([ \%xlc ], [ qw(xlc) ]); } + +# 4. 
Print ML model in ACC format +&write_xlc(\%opts, \%xlc); + +1; + + +# Core functions + +# Goal is to apply a translation to the prob values +sub process_weights() +{ + my ($opts, $cnbc) = @_; + if (CHK && $opts->{k}) { + assert($cnbc->{NC}==2, "Cannot handle $cnbc->{NC} classes\n"); + } + + # 1. First traversal: compute & sum logarithms and flag 0 probs + my ($hasp0, $sumlogs, $minv, $logv) = (0, 0, 0, 0); + for(my $i=0; $i<=$#{$cnbc->{Prior}}; ++$i) { + if (${$cnbc->{Prior}}[$i] == 0) { $hasp0 = 1; } + else { + $logv = log(${$cnbc->{Prior}}[$i]); + $sumlogs += $logv; + ${$cnbc->{Prior}}[$i] = $logv; + if ($logv < $minv) { $minv = $logv; } + } + } + for(my $j=0; $j<$cnbc->{NV}; ++$j) { + my $cpt = "CPT$j"; + for(my $i=0; $i<=$#{$cnbc->{Prior}}; ++$i) { + my $ccl = "C$i"; + for(my $k=0; $k<$cnbc->{$cpt}->{D}; ++$k) { + if (${$cnbc->{$cpt}->{$ccl}}[$k] == 0) { $hasp0 = 1; } + else { + $logv = log(${$cnbc->{$cpt}->{$ccl}}[$k]); + $sumlogs += $logv; + ${$cnbc->{$cpt}->{$ccl}}[$k] = $logv; + if ($logv < $minv) { $minv = $logv; } + } + } + } + } + $mval = $sumlogs - 1; + $tval = ($hasp0) ? -$mval : -$minv; + # 2. Second traversal: update 0 probs, offset weights by T + for(my $i=0; $i<=$#{$cnbc->{Prior}}; ++$i) { + if (${$cnbc->{Prior}}[$i] == 0) { + ${$cnbc->{Prior}}[$i] = $mval; + } + ${$cnbc->{Prior}}[$i] += $tval; + } + for(my $j=0; $j<$cnbc->{NV}; ++$j) { + my $cpt = "CPT$j"; + for(my $i=0; $i<=$#{$cnbc->{Prior}}; ++$i) { + my $ccl = "C$i"; + for(my $k=0; $k<$cnbc->{$cpt}->{D}; ++$k) { + if (${$cnbc->{$cpt}->{$ccl}}[$k] == 0) { + ${$cnbc->{$cpt}->{$ccl}}[$k] = $mval; + } + ${$cnbc->{$cpt}->{$ccl}}[$k] += $tval; + } + } + } + if ($opts->{d}) { warn Data::Dumper->Dump([ $cnbc ], [ qw(cnbc_pw) ]); } +} + +sub reduce_cnbc_xlc() +{ + my ($opts, $cnbc, $xlc) = @_; + $xlc->{NV} = $cnbc->{NV}; + $xlc->{W0} = ${$cnbc->{Prior}}[0] - ${$cnbc->{Prior}}[1]; + $xlc->{NReal} = 0; + $xlc->{NCat} = $cnbc->{NV}; + for(my $j=0; $j<$cnbc->{NV}; ++$j) { + my $cpt = "CPT$j"; + my $cvj = "CVs$j"; + my ($ccl0, $ccl1) = ('C0', 'C1'); + push @{$xlc->{CDs}}, $cnbc->{$cpt}->{D}; + for(my $k=0; $k<$cnbc->{$cpt}->{D}; ++$k) { + my $vdiff = + ${$cnbc->{$cpt}->{$ccl0}}[$k] - ${$cnbc->{$cpt}->{$ccl1}}[$k]; + push @{$xlc->{$cvj}}, $vdiff; + } + } +} + + +# Format parsing functions + +sub read_acc_spec() +{ + my ($fname, $acc) = @_; + + die "Must use common parser!!!!\n"; + + open(my $fh, "<$fname") || + die "Unable to open file $fname. " . 
$f_err_msg; + my ($cc, $cv, $pol, $rmode) = (0, 0, 0, 0); + while(<$fh>) { + chomp; + if ($rmode == 0) { + m/\s*(\d)\s*$/ || die "Unable to match: $_\n"; + ($acc->{NC}, $rmode) = ($1, 1); + } + elsif ($rmode == 1) { + m/\s*(\d+)\s*$/ || die "Unable to match: $_\n"; + ($acc->{NV}, $rmode) = ($1, 2); + } + elsif ($rmode == 2) { + my $class = "C$cc"; + m/\s*(\-?\d+\.?\d*)\s*$/ || die "Unable to match: $_\n"; + $acc->{VV}->{$class}->{W0} = $1; + $rmode = 3; + } + elsif ($rmode == 3) { + my $class = "C$cc"; + my $polarity = "P$pol"; + m/\s*(\-?\d+\.?\d*)\s*$/ || die "Unable to match: $_\n"; + ${$acc->{VV}->{$class}->{$polarity}}[$cv] = $1; + $pol = 1 - $pol; + if ($pol == 0) { $cv++; } + if ($cv == $acc->{NV}) { + ($cc, $cv, $pol) = ($cc+1, 0, 0); + if ($cc == $acc->{NC}) { last; } + $rmode = 2; + } + } else { die "Unexpected line in file: $_\n"; } + } + close($fh); +} + +# Utilities + +sub read_opts() +{ + my ($opts) = @_; + getopts("hdvkf:o:", $opts); + + if ($opts->{h}) { + &prt_help(); + } + elsif (!defined($opts->{f})) { + die "Usage: $0 [-h] [-d] [-v] [-k] [-o <out-file>] -f <cnbc-file>\n" ; + } +} + +sub prt_help() +{ + my $tname = &toolname($0); + print <<"EOF"; +$tname: Translate CNBC format into XLC format +Usage: $tname [-h] [-d] [-v] [-k] [-o <out-file>] -f <cnbc-file> + -f <cnbc-file> specification of CNBC file + -o <out-file> output file for exporting XLC format + -k perform consistency checks & exit if error + -v verbose mode + -d debug mode + -h prints this help + Author: joao.marques-silva\@univ-toulouse.fr +EOF + exit(); +} + +sub toolname() +{ + my ($tname) = @_; + $tname =~ m/([\.\_\-a-zA-Z0-9]+)$/; + return $1; +} + + +#------------------------------------------------------------------------------# +# Auxiliary functions +#------------------------------------------------------------------------------# + +sub resolve_inc() { # Copy from template kept in UTILS package + my ($cref, $pmname) = @_; + my @progname_toks = split(/\//, $0); + pop @progname_toks; + my $progpath = join('/', @progname_toks); + my $fullname = $progpath . '/' . 
$pmname; + open(my $fh, "<$fullname") || die "non-existing file: $pmname\n"; + return $fh; +} + +# jpms diff --git a/pages/application/NaiveBayes/utils/generator_cnbc.py b/pages/application/NaiveBayes/utils/generator_cnbc.py new file mode 100644 index 0000000000000000000000000000000000000000..734e2759b160118e3b3cba50231082ee25179e5d --- /dev/null +++ b/pages/application/NaiveBayes/utils/generator_cnbc.py @@ -0,0 +1,175 @@ +import argparse +import pandas as pd +from sklearn.model_selection import train_test_split +from sklearn.naive_bayes import CategoricalNB +from sklearn.preprocessing import LabelEncoder +from sklearn.metrics import accuracy_score +import pickle +import os +import numpy as np +from scipy.special import logsumexp + +def predict_proba(X, clf, precision=None): + if precision == None: + feature_priors = clf.feature_log_prob_ + class_priors = clf.class_log_prior_ + else: + feature_priors = list(map(lambda x: np.log(np.clip(np.round(np.exp(x), precision), 1e-12, None)), clf.feature_log_prob_)) + class_priors = list(map(lambda x: np.log(np.clip(np.round(np.exp(x), precision), 1e-12, None)), clf.class_log_prior_)) + jll = np.zeros((X.shape[0], 2)) + for i in range(X.shape[1]): + indices = X.values[:, i] + jll += feature_priors[i][:, indices].T + total_ll = jll + class_priors + + log_prob_x = logsumexp(total_ll, axis=1) + return np.argmax(np.exp(total_ll - np.atleast_2d(log_prob_x).T), axis=1) + +if __name__ == "__main__": + parser = argparse.ArgumentParser('Categorical NBC generator.') + parser.add_argument('-d', type=str, help="dataset path") + parser.add_argument('-op', type=str, help="output pickle classifier path", default="") + parser.add_argument('-oc', type=str, help="output NBC classifier path", default="") + parser.add_argument('-oi', type=str, help="output inst path", default="") + parser.add_argument('-ox', type=str, help="output xmap path", default="") + parser.add_argument('-v', type=int, help="verbose", default=0) + parser.add_argument('-p', type=int, help="precision of classifier", default=None) + args = parser.parse_args() + + df = pd.read_csv(args.d) + df.columns = [s.strip() for s in df.columns.values] + + encoders = dict() + min_categories = dict() + for column in df.columns: + if df[column].apply(type).eq(str).all(): + df[column] = df[column].str.strip() + enc = LabelEncoder() + enc.fit(df[column]) + df[column] = enc.transform(df[column]) + min_categories[column] = len(enc.classes_) + encoders[column] = enc + + X = df.drop(df.columns[-1], axis=1) + y = df[df.columns[-1]] + + X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, random_state=0) + clf = CategoricalNB(min_categories=np.array(list(min_categories.values())).astype(int)[:-1]) + clf.fit(X_train, y_train) + + if args.v: + print("----------------------") + print("Initial accuracy:") + print("Train accuracy: ", accuracy_score(clf.predict(X_train), y_train)) + print("Test accuracy: ", accuracy_score(clf.predict(X_test), y_test)) + print("----------------------") + + if args.p is not None: + print("----------------------") + print("Rounded accuracy (precision=" + str(args.p) + "):") + print("Train accuracy: ", accuracy_score(predict_proba(X_train, clf, args.p), y_train)) + print("Test accuracy: ", accuracy_score(predict_proba(X_test, clf, args.p), y_test)) + print("----------------------") + + if args.ox: + if not os.path.exists(os.path.dirname(args.ox)): + os.makedirs(os.path.dirname(args.ox)) + + with open(args.ox, "w") as f: + # --------- Target ----------- + enc = encoders[y.name] 
+ C = len(enc.classes_) + f.write(str(C) + "\n") + for category, target in zip(enc.classes_, enc.transform(enc.classes_)): + f.write(str(target) + " " + str(category) + "\n") + + # --------- Features --------- + n = X.shape[1] + f.write(str(n) + "\n") + + f.write("0" + "\n") + f.write(str(n) + "\n") + for i, feature in enumerate(X.columns): + f.write(str(i) + " " + str(feature) + "\n") + enc = encoders[feature] + f.write(str(len(enc.classes_)) + "\n") + for category, label in zip(enc.classes_, enc.transform(enc.classes_)): + f.write(str(label) + " " + str(category) + "\n") + + """ + FUTURE DEVELOPMENT + # Get types of features (categorical or continuous (=real-valued)) + dtypes = dict() + for column in X.columns: + if len(X[column].unique()) < (X.shape[0] / 3): + dtypes[column] = "categorical" + else: + dtypes[column] = "continuous" + # Real-valued features + f.write(str(len(dict((k, v) for k, v in dtypes.items() if v == "continuous"))) + "\n") + for i, (feature, dtype) in enumerate(dtypes.items()): + if dtype == "continuous": + f.write(str(i) + " " + str(feature) + "\n") + enc = encoders[feature] + f.write(str(len(enc.classes_)) + "\n") + for category, label in zip(enc.classes_, enc.transform(enc.classes_)): + f.write(str(label) + " " + str(category) + "\n") + + # Categorical features + f.write(str(len(dict((k, v) for k, v in dtypes.items() if v == "categorical"))) + "\n") + for i, (feature, dtype) in enumerate(dtypes.items()): + if dtype == "categorical": + f.write(str(i) + " " + str(feature) + "\n") + enc = encoders[feature] + f.write(str(len(enc.classes_)) + "\n") + for category, label in zip(enc.classes_, enc.transform(enc.classes_)): + f.write(str(label) + " " + str(category) + "\n") + """ + + if args.op: + if not os.path.exists(os.path.dirname(args.op)): + os.makedirs(os.path.dirname(args.op)) + pickle.dump(clf, open(args.op, "wb")) + + if args.oc: + if not os.path.exists(os.path.dirname(args.oc)): + os.makedirs(os.path.dirname(args.oc)) + + with open(args.oc, "w") as f: + n = len(clf.classes_) + f.write(str(n) + "\n") + class_priors = np.exp(clf.class_log_prior_) + for i in class_priors: + if args.p is not None: + f.write(str(np.round(np.format_float_positional(i, trim='-'), args.p)) + "\n") + else: + f.write(str(np.format_float_positional(i, trim='-')) + "\n") + m = X.shape[1] + f.write(str(m) + "\n") + + feature_log_priors = clf.feature_log_prob_ + + for feature_log_prior in feature_log_priors: + feature_prior = np.exp(feature_log_prior) + f.write(str(feature_prior.shape[1]) + "\n") + for feature_class_prior in feature_prior: + for v in feature_class_prior: + if args.p is not None: + f.write(str(np.round(np.format_float_positional(v, trim='-'), args.p)) + " ") + else: + f.write(str(np.format_float_positional(v, trim='-')) + " ") + f.write("\n") + + if args.oi: + if not os.path.exists(os.path.dirname(args.oi)): + os.makedirs(os.path.dirname(args.oi)) + + name = next(s for s in reversed(args.oi.split("/")) if s) + for i, (_, sample) in enumerate(X.iterrows()): + path = os.path.join(args.oi, name + "." 
+ str(i+1) + ".txt") + with open(path, "w") as f: + f.write(str(len(sample)) + "\n") + for value in sample: + f.write(str(value) + "\n") + f.write(str(clf.predict([sample])[0]) + "\n") + diff --git a/pages/application/NaiveBayes/utils/test.pl b/pages/application/NaiveBayes/utils/test.pl new file mode 100644 index 0000000000000000000000000000000000000000..0736495ee74c54712523b07d421d015293610012 --- /dev/null +++ b/pages/application/NaiveBayes/utils/test.pl @@ -0,0 +1,3 @@ +print "Called::\n"; + +my $f_err_msg = "Please check file name, existence, permissions, etc.\n"; diff --git a/pages/application/NaiveBayes/utils/xpxlc.pl b/pages/application/NaiveBayes/utils/xpxlc.pl new file mode 100755 index 0000000000000000000000000000000000000000..cbd47596c624e073c40595c393c73e980ab429f8 --- /dev/null +++ b/pages/application/NaiveBayes/utils/xpxlc.pl @@ -0,0 +1,722 @@ +#!/usr/bin/env perl + +## Tool for reasoning about explanations in XLC's. Starting from a +## XLC and associated instance, the tool can enumerate one or more +## explanations. The tool can also validate explanations given in a +## file of explanations. The default mode of operation is to enumerate +## all explanations. +## One example of an ML model that can be reduced to XLC is the NBC. +## The details of the algorithm are included in the accompanying paper. +## The script specifically assumes classification problems with *two* +## classes. The handling of multiple classes is beyond the scope of +## the work. + +## To run the tool: +## <script-name> [-h] [-d] [-v] [-C] [-t] [-s] [-w] [-x] [-k <KKK>] [-n <NNN>] [-p <prt-file>] [-c <xpl-file> [-r]] [-m <xmap-file] -i <cat-inst-file> -f <xlc-file> + +push @INC, \&resolve_inc; + +use strict; +use warnings; + +use Getopt::Std; +use List::Util qw(sum0); ##qw( max min sum sum0); + +use constant DBG => 0; ## Also, comment out unused 'uses' +use constant CHK => 0; + +require Parsers; +import Parsers qw( parse_xlc parse_instance parse_explanations parse_xmap ); + +# 0. Read command line arguments +my %opts = (); +&read_opts(\%opts); + +if ((CHK || DBG) && (defined($opts{k}) || defined($opts{d}))) { + ## Uncomment to use assertions && debug messages + #use Carp::Assert; # Assertions are on. + #if (DBG && $opts{d}) { + # use Data::Dumper; + #} +} +if (defined($opts{p})) { + open ($opts{FH}, '>', $opts{p}); + select($opts{FH}); +} + + +# 1a. Data structures +my %xlc = (); +my %xmap = (); +my %inst = (); +my %xpl = (); + +# 1b. Prepare interrupts +if ($opts{C}) { # If catching system signals + &Utils::register_handlers(); + &Utils::push_arg(\%opts); + &Utils::push_arg(\%xlc); + if ($opts{t}) { &Utils::push_callback(\&print_stats_int); } + if ($opts{s}) { &Utils::push_callback(\&print_summaries_int); } +} + +# 2. Parse NBC XLC +&parse_xlc(\%opts, \%xlc, $opts{f}); +if (DBG && $opts{d}) { print Data::Dumper->Dump([ \%xlc ], [ qw(xlc) ]); } +if (CHK && $xlc{NReal}!=0) { die "Unable to handle real-valued features.\n"; } + +# 3. Parse instance +&parse_instance(\%opts, \%inst, $opts{i}); +if (DBG && $opts{d}) { print Data::Dumper->Dump([ \%inst ], [ qw(inst) ]); } + +# 4. If map specified, load map +if (defined($opts{m})) { + &parse_xmap(\%opts, \%xmap, $opts{m}); +} else { + &set_def_xmap(\%opts, \%xmap, \%xlc); +} +if (DBG && $opts{d}) { print Data::Dumper->Dump([ \%xmap ], [ qw(xmap) ]); } + +# 5. Compute XLC values & preprocess XLC +&simulate_xlc(\%opts, \%xlc, \%inst); +&preprocess_xlc(\%opts, \%xlc, \%inst); +&initialize_data(\%opts, \%xlc, \%inst, \%xmap); + +# 6. 
If *check* mode: read & validate one or more explanations +if ($opts{c}) { + &parse_explanations($opts{c}, \%xpl); + &validate_explanations(\%opts, \%xlc, \%inst, \%xmap, \%xpl); + &print_xpl_status(\%opts, \%xpl); + exit(); +} + +# 7. Else, compute & report explanations +if ($opts{x}) { + &compute_explanations_xl(\%opts, \%xlc, \%inst); +} +else { + &compute_explanations(\%opts, \%xlc, \%inst); +} + +# 8. Print summaries & stats +if ($opts{s}) { &print_summaries(\%opts, \%xlc); } +if ($opts{t}) { &print_stats(\%opts, \%xlc); } + +1; + +# Simulate XLC +sub simulate_xlc() +{ + my ($opts, $xlc, $inst) = @_; + + # Start with the intercept W0 + my $simval = $xlc->{W0}; + + # Add the contribution of real-value variables (currently assumed to be 0) + # ... + if (CHK && $xlc->{NReal} > 0) { + die "Simulation of real-valued features no ready yet.\n"; } + # ... + + # Add the contribution of categorical variables + for (my $i=0; $i<$xlc->{NCat}; ++$i) { + my $cvi = "CVs$i"; + $simval += ${$xlc->{$cvi}}[${$inst->{E}}[$i]]; + } + $xlc->{C} = ($simval > 0) ? 0 : 1; + $xlc->{Gamma} = abs($simval); + + # Validate results + if (CHK && defined($opts->{k})) { + assert($xlc->{C} == $inst->{C}, 'simulated prediction differs'); } + if ($xlc->{C} == 1) { &complement_parameters($opts, $xlc, $inst); } + if (DBG && $opts->{d}) { print Data::Dumper->Dump([$xlc], [qw(xlc_sim)]); } +} + +# If class is 1, then complement all values +sub complement_parameters() +{ + my ($opts, $xlc, $inst) = @_; + + $xlc->{W0} = -$xlc->{W0}; + for(my $i=0; $i<$xlc->{NReal}; ++$i) { + ${$xlc->{RVs}}[$i] = -${$xlc->{RVs}}[$i]; + } + for(my $i=0; $i<$xlc->{NCat}; ++$i) { + my $cvi = "CVs$i"; + for(my $j=0; $j<${$xlc->{CDs}}[$i]; ++$j) { + ${$xlc->{$cvi}}[$j] = -${$xlc->{$cvi}}[$j]; + } + } + $xlc->{C} = 1 - $xlc->{C}; + + if (DBG && $opts->{d}) { print Data::Dumper->Dump([$xlc], [qw(xlc_cps)]); } +} + +# Preprocess XLC +sub preprocess_xlc() +{ + my ($opts, $xlc, $inst) = @_; + # Compute delta's, Delta and Phi [$xlc->{Delta}, $xlc->{DeltaSum}] + &compute_deltas($opts, $xlc, $inst); + if (DBG && $opts->{d}) { print Data::Dumper->Dump([$xlc], [qw(xlc_Ds)]); } + # + # Sort delta's by non-increasing value [$xlc->{SortedDelta}] + &reorder_deltas($opts, $xlc, $inst); + if (DBG && $opts->{d}) { print Data::Dumper->Dump([$xlc], [qw(xlc_Sort)]); } + # + &calc_partial_sums($opts, $xlc); +} + +sub compute_deltas() +{ + my ($opts, $xlc, $inst) = @_; + + # a. For each feature, sort weights, and pick smallest + my $sumdelta = 0; + for (my $i=0; $i<$xlc->{NCat}; ++$i) { + my $cvi = "CVs$i"; + my $tval = ${$xlc->{$cvi}}[${$inst->{E}}[$i]]; + my @scvs = sort { $a <=> $b} @{$xlc->{$cvi}}; + ${$xlc->{Delta}}[$i] = $tval - $scvs[0]; + ##print ("i=$i: tval=$tval vs. minv=$scvs[0] vs. delta=${$xlc->{Delta}}[$i]\n"); + $sumdelta += ${$xlc->{Delta}}[$i]; + } + $xlc->{DeltaSum} = $sumdelta; + $xlc->{Phi} = $xlc->{DeltaSum} - $xlc->{Gamma}; + $xlc->{PhiRef} = $xlc->{Phi}; + + # b. 
Validations + if (DBG && $opts->{d}) { print "SumDelta: $sumdelta\n"; } + if (CHK && defined($opts->{k}) && $sumdelta <= $xlc->{Phi}) { + my $msg = 'XLC prediction cannot be changed!?'; + if ($opts->{k}>1) { &prt_err_exit($msg); } + elsif ($opts->{k}==1) { &prt_warn($msg); } + } +} + +sub reorder_deltas() +{ + my ($opts, $xlc, $inst) = @_; + + my %DMap = (); + $xlc->{DeltaMap} = {}; + for(my $i=0; $i<$xlc->{NCat}; ++$i) { + my $rval = ${$xlc->{Delta}}[$i]; + push @{$xlc->{DeltaMap}->{$rval}}, $i; + $DMap{$rval} = 1; + } + @{$xlc->{SortedDelta}} = sort { $b <=> $a } @{$xlc->{Delta}}; + @{$xlc->{SDelta}} = (); + for(my $i=0; $i<=$#{$xlc->{SortedDelta}}; ++$i) { + my $rval = ${$xlc->{SortedDelta}}[$i]; + if ($DMap{$rval} == 0) { next; } + if (DBG && $opts->{d}) { + print "A: SDelta \@ i=$i: @{$xlc->{SDelta}} && rval=$rval\n"; } + push @{$xlc->{SDelta}}, @{$xlc->{DeltaMap}->{$rval}}; + $DMap{$rval} = 0; + } + if (DBG && $opts->{d}) { print Data::Dumper->Dump([$xlc], [qw(xlc_Reord)]); } + if (CHK && defined($opts->{k})) { + my ($sdz, $dz) = ($#{$xlc->{SDelta}}+1, $#{$xlc->{Delta}}+1); + assert($sdz == $dz, "Different sizes: $sdz vs. $dz"); } +} + +sub calc_partial_sums() +{ + my ($opts, $xlc) = @_; + + my ($depth, $tmpv) = ($xlc->{NCat}-1, 0); + while($depth>=0) { + $tmpv += ${$xlc->{SortedDelta}}[$depth]; + $xlc->{SumFrom}[$depth--] = $tmpv; + } +} + +sub set_def_xmap() +{ + my ($opts, $xmap, $xlc) = @_; + + $xmap->{NC} = 2; # Default number of classes... + @{$xmap->{ClMap}} = ('0', '1'); + $xmap->{NC} = $xlc->{NV}; + $xmap->{NReal} = $xlc->{NReal}; + for (my $i=0; $i<$xlc->{NReal}; ++$i) { + ${$xmap->{VMap}}[$i] = "v$i"; + } + $xmap->{NCat} = $xlc->{NCat}; + @{$xmap->{CDs}} = @{$xlc->{CDs}}; + for (my $i=0; $i<$xlc->{NCat}; ++$i) { + my $cid = $xmap->{NReal}+$i; + ${$xmap->{VMap}}[$i] = "v$cid"; + for (my $j=0; $j<${$xlc->{CDs}}[$i]; ++$j) { + ${$xmap->{CMap}->{$i}}[$j] = "$j"; + } + } + if (DBG && $opts->{d}) { print Data::Dumper->Dump([$xmap],[qw(xmap_def)]); } +} + +sub initialize_data() +{ + my ($opts, $xlc, $inst, $xmap) = @_; + + ($xlc->{XplNum}, $xlc->{XplSz}) = (0, 0); + ($xlc->{XplMin}, $xlc->{XplMax}) = ($xlc->{NV}, 0); + for(my $idx=0; $idx<=$xlc->{NV}; ++$idx) { + ${$xlc->{CNTS}}[$idx] = 0; + } + for(my $idx=0; $idx<$xlc->{NReal}; ++$idx) { + die "Handling of real-valued features not yet implemented...\n"; + } + for(my $idx=0; $idx<$xlc->{NCat}; ++$idx) { + # Categorical feature name + my $cval = ${$inst->{E}}[$idx]; + my $vname = ${$xmap->{VMap}}[$xmap->{NReal}+$idx]; + my $cname = ${$xmap->{CMap}->{$idx}}[$cval]; + ${$xlc->{LITS}}[$idx] = "$vname=$cname"; + } + if (DBG && $opts->{d}) { print Data::Dumper->Dump([$xlc], [qw(xlc_init)]); } +} + + +# Reference implementation +sub compute_explanations() +{ + my ($opts, $xlc, $inst) = @_; + + ##if ($xlc->{Lambda} < 0) { print ("Expl: true\n"); return; } ???? + # + my @xp = (-1) x ($xlc->{NV}+1); my @tog = (-1) x $xlc->{NV}; my $depth=-1; + my $cntxp = (defined($opts->{n})) ? 1 : 0; my $numxp = 0; $xp[0] = 1; + if (DBG && $opts->{d}) { print Data::Dumper->Dump([$xlc], [qw(xlc_xpA)]); } + while (1) { + # 1. Find another explanation + #if (DBG && $opts->{d}) { print("\@Depth: $depth\n"); + # &prt_xp_snapshot($xlc,\%xp,\@tog,$depth,1); } + $depth = &find_one_explanation($opts, $xlc, $inst, \@tog, \@xp, $depth); + &report_explanation($opts, $xlc, \@xp); + if ($cntxp && ++$numxp == $opts->{n}) { last; } + # 2. 
Enter consistent state + $depth = &enter_valid_state($opts, $xlc, $inst, \@tog, \@xp, $depth); + if ($depth < 0) { return; } + if (DBG && $opts->{d}) { &prt_xp_snapshot($xlc,\@xp,\@tog,$depth,0); } + } + if (DBG && $opts->{d}) { print Data::Dumper->Dump([$xlc], [qw(xlc_xpB) ]); } +} + +sub find_one_explanation() +{ + my ($opts, $xlc, $inst, $tog, $xp, $idx) = @_; + + while ($xlc->{Phi} >= 0) { + if (DBG && defined($opts->{d})) { print "Depth(down): $idx\n"; } + if (CHK && defined($opts->{k})) { + assert($idx<$xlc->{NV}); + assert($idx==$xlc->{NV}-1 || ${$tog}[$idx+1]==-1); } + ${$tog}[++$idx] = 0; + $xlc->{Phi} -= ${$xlc->{SortedDelta}}[$idx]; + ®_literal($opts, $xp, $xlc, $inst, $idx); + if (DBG && $opts->{d}) { &prt_xp_snapshot($xlc, $xp, $tog, $idx, 0); } + } + if (CHK && defined($opts->{k})) { + assert($xlc->{Phi}<0); &chk_explanation($opts, $xlc, $xp, $tog); } + return $idx; +} + +sub enter_valid_state() +{ + my ($opts, $xlc, $inst, $tog, $xp, $idx) = @_; + + while (!&consistent_state($opts, $xlc, $idx)) { + if (DBG && defined($opts->{d})) { print "Depth(up): $idx\n"; } + while ($idx>=0 && ${$tog}[$idx]==1) { ${$tog}[$idx--] = -1; } + if ($idx < 0) { return $idx; } # Terminate + # Drop literal from explanation + if (CHK && defined($opts->{k})) { assert(${$tog}[$idx]==0); } + &unreg_literal($opts, $xp, $xlc, $inst, $idx); + $xlc->{Phi} += ${$xlc->{SortedDelta}}[$idx]; + if (CHK && defined($opts->{k})) { assert(${$tog}[$idx]==0); } + ${$tog}[$idx] = 1; + if (DBG && $opts->{d}) { &prt_xp_snapshot($xlc, $xp, $tog, $idx, 1); } + } + return $idx; +} + +sub consistent_state() +{ + my ($opts, $xlc, $idx) = @_; + + my $stok = + ($xlc->{Phi} < 0 || $idx == $xlc->{NV}-1 || + ${$xlc->{SumFrom}}[$idx+1] <= $xlc->{Phi}); + return ($stok) ? 0 : 1; +} + +sub reg_literal() +{ + my ($opts, $xp, $xlc, $inst, $idx) = @_; + my $lit = ${$xlc->{SDelta}}[$idx]; + if (CHK) { assert(${$xp}[0] <= $#{$xp}, "XP idx above limit??"); } + ${$xp}[${$xp}[0]++] = $lit; + if (CHK) { assert(${$xp}[${$xp}[0]] == -1, "Pointing to wrong pos!?"); } +} + +sub unreg_literal() +{ + my ($opts, $xp, $xlc, $inst, $idx) = @_; + if (CHK) { assert(${$xp}[0] > 0, "XP idx below limit??"); } + ${$xp}[--${$xp}[0]] = -1; + if (CHK) { assert(${$xp}[${$xp}[0]] == -1, "Pointing to wrong pos!?"); } +} + +sub report_explanation() +{ + my ($opts, $xlc, $xp) = @_; + + # Obs: No actual need to sort; we can keep a sorted list. This is faster... + if ($opts->{w}) { + ##$" = ', '; + my $tlits = $xlc->{LITS}; + my @slice = @{$xp}[1 .. (${$xp}[0]-1)]; + if (DBG && $opts->{d}) { print ("Slice: @slice\n"); } + my @sslice = sort { ${$tlits}[$a] cmp ${$tlits}[$b] } @slice; + if (DBG && $opts->{d}) { print ("Sorted Slice: @slice\n"); } + my @xplits = map { ${$xlc->{LITS}}[$_] } @sslice; + if (DBG && $opts->{d}) { print ("Exp Lits: @xplits\n"); } + + if (CHK && $opts->{k}) { + for(my $i=1; $i<=$#xplits; ++$i) { + assert($xplits[$i-1] ne $xplits[$i], + "Duplicate literals in explanation: $xplits[$i-1] vs. $xplits[$i]\n" . + "Exp: @xplits\n"); + } + } + + #my @xplits = sort {abs($a) <=> abs($b)} keys %{$xp}; + if (!$opts->{v}) { + ##local $"=', '; + print("Expl: @xplits\n"); + } + else { + my $sz = sprintf("_(\#%d/%d", $#xplits+1,$xlc->{NV}); + my $wt = (defined($opts->{k})) ? 
+                sprintf(";W:%3.2f)", $xlc->{Phi}) : ')';
+            ##{ ##local $"=', ';
+            print("Expl$sz$wt: @xplits\n");
+            ## }
+        }
+    }
+    if ($opts->{t} || $opts->{s}) {
+        if ($opts->{t}) {
+            my $nlits = ${$xp}[0]-1;
+            $xlc->{XplSz} += $nlits;
+            if ($xlc->{XplMin} > $nlits) { $xlc->{XplMin} = $nlits; }
+            if ($xlc->{XplMax} < $nlits) { $xlc->{XplMax} = $nlits; }
+        }
+        if ($opts->{s}) {
+            my ($Cnts, $num) = ($xlc->{CNTS}, $xp->[0]);
+            for (my $idx=1; $idx<$num; ++$idx) {
+                ${$Cnts}[$xp->[$idx]]++;
+            }
+        }
+        $xlc->{XplNum}++;
+    }
+    if (DBG && $opts->{d}) { &prt_flush(); }
+}
+
+sub prt_xp_snapshot()
+{
+    my ($xlc, $xp, $tog, $depth, $mf) = @_;
+
+    my $msg = ($mf) ? '@Up:' : '@Down:';
+    print ("$msg\n");
+    print ("Phi: $xlc->{Phi}\n");
+    print ("Deltas: [ @{$xlc->{SortedDelta}} ]\n");
+    print ("SDelta: [ @{$xlc->{SDelta}} ]\n");
+    print ("CNTS: [ @{$xlc->{CNTS}} ]\n");
+    print ("LITS: [ @{$xlc->{LITS}} ]\n");
+    my $lstidx = ${$xp}[0]-1;
+    print ("XP keys: ${$xp}[0] + [ @{$xp}[1..$lstidx] ]\n");
+    print ("XP vect: [ @{$xp} ]\n");
+    print ("Togs: [ @{$tog} ]\n");
+    print ("Depth: $depth\n");
+    &prt_flush();
+}
+
+sub chk_explanation()
+{
+    my ($opts, $xlc, $xp, $tog) = @_;
+
+    my ($phi, $ntogs) = ($xlc->{PhiRef}, 0);
+    for(my $i=0; $i<=$#{$tog}; ++$i) {
+        if (${$tog}[$i]==0) {
+            $phi -= ${$xlc->{SortedDelta}}[$i];
+            $ntogs++;
+        }
+    }
+    assert($phi < 0);
+    assert($ntogs == ${$xp}[0]-1);
+}
+
+
+# Alternative (faster) implementation
+sub compute_explanations_xl()
+{
+    my ($opts, $xlc, $inst) = @_;
+
+    my @SortedDelta = @{$xlc->{SortedDelta}};
+    my @SDelta = @{$xlc->{SDelta}};
+    my @SumFrom = @{$xlc->{SumFrom}};
+    my @xp = (-1) x ($xlc->{NV}+1); my @tog = (-1) x $xlc->{NV}; my $depth=-1;
+    my $cntxp = (defined($opts->{n})) ? 1 : 0; my $numxp = 0; $xp[0]=1;
+    if (DBG && $opts->{d}) { print Data::Dumper->Dump([$xlc], [qw(xlc_xpA)]); }
+    while (1) {
+        # 1. Find another explanation
+        while ($xlc->{Phi} >= 0) {
+            if (DBG && defined($opts->{d})) { print "Depth(down): $depth\n"; }
+            if (CHK && defined($opts->{k})) {
+                assert($depth<$xlc->{NV});
+                assert($depth==$xlc->{NV}-1 || $tog[$depth+1]==-1); }
+            $tog[++$depth] = 0;
+            $xlc->{Phi} -= $SortedDelta[$depth];
+            $xp[$xp[0]++] = $SDelta[$depth];
+            if (DBG && $opts->{d}) {
+                &prt_xp_snapshot($xlc,\@xp,\@tog,$depth,0); }
+        }
+        if (CHK && defined($opts->{k})) {
+            assert($xlc->{Phi}<0); &chk_explanation($opts,$xlc,\@xp,\@tog); }
+        &report_explanation($opts, $xlc, \@xp);
+        if ($cntxp && ++$numxp == $opts->{n}) { last; }
+
+        # 2. Enter consistent state
+        while ($xlc->{Phi} < 0 || $depth == $xlc->{NV}-1 ||
+               $SumFrom[$depth+1] <= $xlc->{Phi}) {
+            if (DBG && defined($opts->{d})) { print "Depth(up): $depth\n"; }
+            while ($depth>=0 && $tog[$depth]==1) { $tog[$depth--] = -1; }
+            if ($depth < 0) { return $depth; }    # Terminate
+            # Drop literal from explanation
+            if (CHK && defined($opts->{k})) { assert($tog[$depth]==0); }
+            $xp[--$xp[0]] = -1;
+            $xlc->{Phi} += $SortedDelta[$depth];
+            if (CHK && defined($opts->{k})) { assert($tog[$depth]==0); }
+            $tog[$depth] = 1;
+            if (DBG && $opts->{d}) {
+                &prt_xp_snapshot($xlc,\@xp,\@tog,$depth,1); }
+        }
+        if ($depth < 0) { return; }
+        if (DBG && $opts->{d}) { &prt_xp_snapshot($xlc,\@xp,\@tog,$depth,0); }
+    }
+    if (DBG && $opts->{d}) { print Data::Dumper->Dump([$xlc], [qw(xlc_xpB) ]); }
+}
+
+sub validate_explanations()
+{
+    my ($opts, $xlc, $inst, $xmap, $xpl) = @_;
+
+    %{$xmap->{IVMap}} = ();
+    for (my $i=0; $i<=$#{$xmap->{VMap}}; ++$i) {
+        $xmap->{IVMap}->{${$xmap->{VMap}}[$i]} = $i;
+    }
+    # Traverse & validate given explanations
+    ($xpl->{XPStr}, $xpl->{Status}, $xpl->{RedLits}) = ([], [], []);
+    foreach my $xpvec (@{$xpl->{Expl}}) {
+        push @{$xpl->{XPStr}}, "@{$xpvec}";
+        # 1. Check entailment
+        my $phi = $xlc->{PhiRef};
+        foreach my $lit (@{$xpvec}) {
+            $lit =~ m/([^=]+)=([^=]+)/ || die "Unable to match literal: $lit\n";
+            my ($svar, $sval) = ($1, $2);
+            ##print ("(svar,sval)=($svar,$sval)\n");
+            ##print ("IVMap{svar}: $xmap->{IVMap}->{$svar}\n");
+            my $var = $xmap->{IVMap}->{$svar}-$xmap->{NReal};
+            $phi -= ${$xlc->{Delta}}[$var];
+            ##print ("Current Phi:$phi\n");
+        }
+        if ($phi >= 0) {
+            push @{$xpl->{Status}}, -1;
+            push @{$xpl->{RedLits}}, [];
+            next;
+        }
+        # 2. Check redundancy
+        if (CHK && defined($opts->{k})) { assert($phi < 0); }
+        my $RedLits = [];
+        foreach my $lit (@{$xpvec}) {
+            $lit =~ m/([^=]+)=([^=]+)/ || die "Unable to match literal: $lit\n";
+            my ($svar, $sval) = ($1, $2);
+            my $var = $xmap->{IVMap}->{$svar}-$xmap->{NReal};
+            if ($phi + $xlc->{Delta}[$var] < 0) { push @{$RedLits}, $lit; }
+        }
+        push @{$xpl->{RedLits}}, $RedLits;
+        if (@{$RedLits}) { push @{$xpl->{Status}}, 1; next; }
+        push @{$xpl->{Status}}, 0;
+    }
+    return;
+}
+
+sub print_xpl_status()
+{
+    my ($opts, $xpl) = @_;
+
+    ###($xpl->{XPStr}, $xpl->{Status}, $xpl->{RedLits}) = ([], [], []);
+    for(my $i=0; $i<=$#{$xpl->{XPStr}}; ++$i) {
+        print ("Expl: ${$xpl->{XPStr}}[$i] => ");
+        my $xpst = ${$xpl->{Status}}[$i];
+        my ($msg, $redlits) = ('', '');
+        if ($xpst == 0) {
+            $msg = 'Confirmed as (subset-minimal) explanation';
+        }
+        elsif ($xpst < 0) {
+            $msg = 'NOT an explanation, i.e. entailment does not hold';
+        }
+        else {
+            $msg = 'Redundant explanation. Example of redundant literals: ';
+            $redlits = "@{${$xpl->{RedLits}}[$i]}";
+        }
+        print ("$msg$redlits\n");
+    }
+}
+
+sub print_stats()
+{
+    my ($opts, $xlc) = @_;
+
+    my $tname = uc(&toolname($0));
+    print "\n$tname stats:\n";
+    my $tsz = (defined($opts->{n})) ? "$opts->{n}" : 'all';
+    print ("Target explanations: $tsz\n");
+    my $avgsz = sprintf("%.2f", $xlc->{XplSz} / $xlc->{XplNum});
+    print ("Number of explanations: $xlc->{XplNum}\n");
+    print ("Average explanation size: $avgsz\n");
+    print ("Smallest explanation: $xlc->{XplMin}\n");
+    print ("Largest explanation: $xlc->{XplMax}\n");
+}
+
+sub print_summaries()
+{
+    my ($opts, $xlc) = @_;
+
+    my $tname = uc(&toolname($0));
+    print "\n$tname summary:\n";
+    my $hsz = 0;
+    for (my $idx=0; $idx <= $#{$xlc->{CNTS}}; ++$idx) {
+        if (${$xlc->{CNTS}}[$idx] != 0) { $hsz++; }
+    }
+    my $tsz = (defined($opts->{n})) ? "$opts->{n}" : 'all';
+    print "Target explanations: $tsz\n";
+    my $avgsz = sprintf("%.2f", $xlc->{XplSz} / $xlc->{XplNum});
+    print "Number of explanations: $xlc->{XplNum}\n";
+    print "Histogram size: $hsz\n";
+    print "Literal distribution in explanations:\n";
+    my $tcnts = $xlc->{CNTS};
+    my @skeys = (0 .. $xlc->{NV}-1);
+    @skeys = sort { abs(${$tcnts}[$a]) <=> abs(${$tcnts}[$b]) } @skeys;
+    foreach my $key (@skeys) {
+        next if ${$xlc->{CNTS}}[$key] <= 0;
+        my $lit = ${$xlc->{LITS}}[$key];
+        print("$lit: ${$xlc->{CNTS}}[$key]\n");
+    }
+}
+
+sub print_stats_int()
+{
+    my $args = shift @_;
+
+    my ($opts, $xlc) = @{$args};
+    &print_stats($opts, $xlc);
+}
+
+sub print_summaries_int()
+{
+    my $args = shift @_;
+
+    my ($opts, $xlc) = @{$args};
+    &print_summaries($opts, $xlc);
+}
+
+
+# Utilities
+
+sub read_opts()
+{
+    my ($opts) = @_;
+    getopts("hdvCtswxk:n:c:rp:m:f:i:", $opts);
+
+    if ($opts->{h}) {
+        &prt_help();
+    }
+    elsif (!defined($opts->{f}) || !defined($opts->{i})) {
+        ##||
+        ##(defined($opts->{c}) && defined($opts->{i})) ||
+        ##(!defined($opts->{c}) && !defined($opts->{i}))) {
+        die "Usage: $0 [-h] [-d] [-v] [-C] [-t] [-s] [-w] [-x] [-k <KKK>] [-n <NNN>] [-p <prt-file>] [-c <chk-xpl> [-r]] [-m <xmap-file>] -i <cat-inst-file> -f <xlc-file>\n" ;
+    }
+}
+
+sub prt_help()
+{
+    my $tname = &toolname($0);
+    print <<"EOF";
+$tname: Compute explanations of XLCs (including NBCs) with polynomial delay
+Usage: $tname [-h] [-d] [-v] [-C] [-t] [-s] [-w] [-x] [-k <KKK>] [-n <NNN>] [-p <prt-file>] [-c <xpl-file> [-r]] [-m <xmap-file>] -i <cat-inst-file> -f <xlc-file>
+  -f <xlc-file>   specification of XLC file
+  -i <inst-file>  specification of instance
+  -c <xpl-file>   check/validate explanation
+  -m <xmap-file>  map file
+  -p <prt-file>   print to file
+  -n <NNN>        number of NNN explanations to list (the default is all)
+  -k <KKK>        apply consistency checks & issue warnings (1) or exit (>1)
+  -r              repair explanations (when validating explanations) [not yet available]
+  -x              run faster implementation
+  -w              write computed explanations
+  -s              summarize computed explanations
+  -t              gather stats on computed explanations
+  -C              enable catching system signals
+  -v              verbose mode
+  -d              debug mode
+  -h              prints this help
+  Author: joao.marques-silva\@univ-toulouse.fr
+EOF
+    exit();
+}
+
+sub prt_warn()
+{
+    my ($msg) = @_;
+    print("*** $0 warning ***: $msg\n");
+}
+
+sub prt_err_exit()
+{
+    my ($msg) = @_;
+    print("*** $0 error ***: $msg\n");
+    exit();
+}
+
+sub toolname()
+{
+    my ($tname) = @_;
+    $tname =~ m/([\.\_\-a-zA-Z0-9]+)$/;
+    return $1;
+}
+
+sub prt_flush()
+{
+    select()->flush();
+}
+
+
+#------------------------------------------------------------------------------#
+# Auxiliary functions
+#------------------------------------------------------------------------------#
+
+sub resolve_inc() {    # Copy from template kept in UTILS package
+    my ($cref, $pmname) = @_;
+    my @progname_toks = split(/\//, $0);
+    pop @progname_toks;
+    my $progpath = join('/', @progname_toks);
+    my $fullname = $progpath . '/' . $pmname;
+    open(my $fh, "<$fullname") || die "non-existing file: $pmname\n";
+    return $fh;
+}
+
+# jpms
diff --git a/pages/application/application.py b/pages/application/application.py
index db54577aa48a31e87069afd59ebf34ed6a346ea0..29e01f9bc87e6f9d3c57e82ebb2547b17ca70a17 100644
--- a/pages/application/application.py
+++ b/pages/application/application.py
@@ -3,6 +3,8 @@ import dash_bootstrap_components as dbc
 import dash_daq as daq
 
 from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent
+from pages.application.NaiveBayes.NaiveBayesComponent import NaiveBayesComponent
+import subprocess
 
 class Application():
     def __init__(self, view):
@@ -31,7 +33,10 @@
         self.instance = ''
 
         self.list_expls = []
-        self.expl_path = []
+        self.list_cont_expls = []
+
+        self.expl=''
+        self.cont_expl=''
 
         self.component_class = ''
         self.component = ''
@@ -56,24 +61,28 @@
 
     def update_instance(self, instance):
         self.instance = instance
-        self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)
+        self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)
 
     def update_enum(self, enum):
         self.enum = enum
-        self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)
+        self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)
 
     def update_xtype(self, xtype):
         self.xtype = xtype
-        self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)
+        self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)
 
     def update_solver(self, solver):
         self.solver = solver
-        self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)
+        self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)
 
     def update_expl(self, expl):
         self.expl = expl
        self.component.draw_explanation(self.instance, expl)
 
+    def update_cont_expl(self, cont_expl):
+        self.expl = cont_expl
+        self.component.draw_contrastive_explanation(self.instance, cont_expl)
+
 
 class View():
     def __init__(self, model):
@@ -178,12 +187,16 @@
 
                             ], className="sidebar")])
 
-        self.expl_choice = html.Div([html.H5(id = "navigate_label", hidden=True, children="Navigate through the explanations and plot them on the tree : "),
-                                    html.Div(id='navigate_dropdown', hidden=True,
-                                    children = [dcc.Dropdown(self.model.list_expls,
+        self.expl_choice = html.Div(id = "interaction_graph", hidden=True,
+                                    children=[html.H5("Navigate through the explanations and plot them on the tree : "),
+                                    html.Div(children = [dcc.Dropdown(self.model.list_expls,
                                         id='expl_choice',
+                                        className="dropdown")]),
+                                    html.H5("Navigate through the contrastive explanations and plot them on the tree : "),
+                                    html.Div(children = [dcc.Dropdown(self.model.list_cont_expls,
+                                        id='cont_expl_choice',
                                         className="dropdown")])])
-        
+
         self.layout = dbc.Row([
             dbc.Col([self.sidebar], width=3, class_name="sidebar"),
             dbc.Col([dbc.Row(id = "graph", children=[]),
diff --git a/utils.py b/utils.py
index 42b30c99af3af1b8463b3aa8e29f1834e974e0b3..d9996f5443cd611798e68d37beb3f73d70aa4cc3 100644
--- a/utils.py
+++ b/utils.py
@@ -27,9 +27,9 @@ def parse_contents_data(contents, filename):
     decoded = base64.b64decode(content_string)
     try:
         if '.csv' in filename:
-            data = decoded.decode('utf-8')
+            data = decoded.decode('utf-8').strip()
         if '.txt' in filename:
-            data = decoded.decode('utf-8')
+            data = decoded.decode('utf-8').strip()
     except Exception as e:
         print(e)
         return html.Div([
@@ -51,7 +51,11 @@ def parse_contents_instance(contents, filename):
             data = str(data).strip().split(',')
             data = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in data]))
         elif '.json' in filename:
-            data = decoded.decode('utf-8')
+            data = decoded.decode('utf-8').strip()
+            data = json.loads(data)
+            data = list(tuple(data.items()))
+        elif '.inst' in filename:
+            data = decoded.decode('utf-8').strip()
             data = json.loads(data)
             data = list(tuple(data.items()))
     except Exception as e: