Skip to content
Snippets Groups Projects
Commit 0e0b3aa4 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

added naive bayes

parent e4166e3e
No related branches found
No related tags found
1 merge request: !4 "end Decision tree"
This commit is part of merge request !4. Comments created here will be created in the context of that merge request.
import dash
import pandas as pd
from dash import Input, Output, State
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
from utils import parse_contents_graph, parse_contents_instance, parse_contents_data
def register_callbacks(page_home, page_course, page_application, app):
    """Register every Dash callback of the site on ``app``.

    Args:
        page_home: Layout returned for the ``/`` route.
        page_course: Layout returned for the ``/course`` route.
        page_application: Object exposing ``view.layout`` (layout of the
            ``/application`` route) and ``model`` (the application state the
            callbacks read and mutate).
        app: Object whose ``callback`` decorator registers the callbacks.

    NOTE(review): several callbacks below write to ``Output('graph',
    'children')`` and ``Output('explanation', 'children')``; confirm that the
    Dash version/configuration in use accepts outputs shared between
    callbacks.
    """
    page_list = ['home', 'course', 'application']

    @app.callback(
        Output('page-content', 'children'),
        Input('url', 'pathname'))
    def display_page(pathname):
        """Return the layout for ``pathname`` (``None`` for unknown paths)."""
        if pathname == '/':
            return page_home
        if pathname == '/application':
            return page_application.view.layout
        if pathname == '/course':
            return page_course
        return None

    @app.callback(Output('home-link', 'active'),
                  Output('course-link', 'active'),
                  Output('application-link', 'active'),
                  Input('url', 'pathname'))
    def navbar_state(pathname):
        """Tell each navbar link whether it points at the current page."""
        home_active, course_active, application_active = (
            pathname == f'/{name}' for name in page_list)
        # display_page serves the home page at '/', so that path must
        # highlight the home link as well as '/home'.
        return home_active or pathname == '/', course_active, application_active

    @app.callback(
        Output('graph', 'children'),
        Input('ml_model_choice', 'value'),
        prevent_initial_call=True
    )
    def update_ml_type(value_ml_model):
        """Store the chosen model type and clear the displayed graph."""
        model_application = page_application.model
        model_application.update_ml_model(value_ml_model)
        return None

    @app.callback(
        Output('pretrained_model_filename', 'children'),
        Output('graph', 'children'),
        Input('ml_pretrained_model_choice', 'contents'),
        State('ml_pretrained_model_choice', 'filename'),
        prevent_initial_call=True
    )
    def update_ml_pretrained_model(pretrained_model_contents, pretrained_model_filename):
        """Load the uploaded model; draw it unless extra info is still awaited."""
        model_application = page_application.model
        if model_application.ml_model is None:
            raise PreventUpdate
        graph = parse_contents_graph(pretrained_model_contents, pretrained_model_filename)
        model_application.update_pretrained_model(graph)
        if not model_application.add_info:
            model_application.update_pretrained_model_layout()
            return pretrained_model_filename, model_application.component.network
        # The layout is built later, once the info file has been uploaded.
        return pretrained_model_filename, None

    @app.callback(
        Output('info_filename', 'children'),
        Output('graph', 'children'),
        Input('model_info_choice', 'contents'),
        State('model_info_choice', 'filename'),
        prevent_initial_call=True
    )
    def update_info_model(model_info, model_info_filename):
        """Parse the uploaded info file and rebuild the model layout with it."""
        model_application = page_application.model
        if model_application.ml_model is None:
            raise PreventUpdate
        model_info = parse_contents_data(model_info, model_info_filename)
        model_application.update_pretrained_model_layout_with_info(model_info, model_info_filename)
        return model_info_filename, model_application.component.network

    @app.callback(
        Output('instance_filename', 'children'),
        Output('graph', 'children'),
        Output('explanation', 'children'),
        Input('ml_instance_choice', 'contents'),
        State('ml_instance_choice', 'filename'),
        prevent_initial_call=True
    )
    def update_instance(instance_contents, instance_filename):
        """Load the uploaded instance once model, enum and xtype are all set."""
        model_application = page_application.model
        if (model_application.ml_model is None
                or model_application.pretrained_model is None
                or model_application.enum <= 0
                or model_application.xtype is None):
            raise PreventUpdate
        instance = parse_contents_instance(instance_contents, instance_filename)
        model_application.update_instance(instance)
        return instance_filename, model_application.component.network, model_application.component.explanation

    @app.callback(
        Output('explanation', 'children'),
        Input('number_explanations', 'value'),
        prevent_initial_call=True
    )
    def update_enum(enum):
        """Apply a new number of explanations and refresh the explanation."""
        model_application = page_application.model
        if (model_application.ml_model is None
                or model_application.pretrained_model is None
                or len(model_application.instance) == 0
                or model_application.xtype is None):
            raise PreventUpdate
        model_application.update_enum(enum)
        return model_application.component.explanation

    @app.callback(
        Output('explanation', 'children'),
        Input('explanation_type', 'value'),
        prevent_initial_call=True
    )
    def update_xtype(xtype):
        """Apply a new explanation type and refresh the explanation."""
        model_application = page_application.model
        if (model_application.ml_model is None
                or model_application.pretrained_model is None
                or len(model_application.instance) == 0
                or model_application.enum <= 0):
            raise PreventUpdate
        model_application.update_xtype(xtype)
        return model_application.component.explanation

    @app.callback(
        Output('explanation', 'children'),
        Input('solver_sat', 'value'),
        prevent_initial_call=True
    )
    def update_solver(solver):
        """Apply a new SAT solver choice and refresh the explanation."""
        model_application = page_application.model
        if (model_application.ml_model is None
                or model_application.pretrained_model is None
                or len(model_application.instance) == 0
                or model_application.enum <= 0
                or len(model_application.xtype) == 0):
            raise PreventUpdate
        model_application.update_solver(solver)
        return model_application.component.explanation

    @app.callback(
        Output('graph', 'children'),
        Input('expl_choice', 'value'),
        prevent_initial_call=True
    )
    def update_expl_choice(expl_choice):
        """Redraw the graph for the explanation picked in the dropdown."""
        model_application = page_application.model
        if (model_application.ml_model is None
                or model_application.pretrained_model is None
                or len(model_application.instance) == 0
                or model_application.enum <= 0
                or len(model_application.xtype) == 0):
            raise PreventUpdate
        model_application.update_expl(expl_choice)
        return model_application.component.network

    @app.callback(
        Output('explanation', 'hidden'),
        Output('navigate_label', 'hidden'),
        Output('navigate_dropdown', 'hidden'),
        Output('expl_choice', 'options'),
        Input('explanation', 'children'),
        Input('explanation_type', 'value'),
        prevent_initial_call=True
    )
    def layout_buttons_navigate_expls(explanation, explanation_type):
        """Show or hide the explanation panel and its navigation dropdown."""
        # ``not explanation_type`` also covers None, on which len() would raise.
        if explanation is None or not explanation_type:
            return True, True, True, {}
        if "AXp" not in explanation_type and "CXp" in explanation_type:
            # Only contrastive explanations requested: no navigation widgets.
            return False, True, True, {}
        options = {str(expl): expl for expl in page_application.model.list_expls}
        return False, False, False, options

    @app.callback(
        Output('choice_info_div', 'hidden'),
        Input('add_info_model_choice', 'on'),
        prevent_initial_call=True
    )
    def add_model_info(add_info_model_choice):
        """Record whether extra model info is needed; hide the upload if not."""
        model_application = page_application.model
        model_application.update_info_needed(add_info_model_choice)
        return not add_info_model_choice
......@@ -23,6 +23,8 @@ def create_legend(g):
legend.add_node("b", style = "invis")
legend.add_node("c", style = "invis")
legend.add_node("d", style = "invis")
legend.add_node("e", style = "invis")
legend.add_node("f", style = "invis")
legend.add_edge("a","b")
edge = legend.get_edge("a","b")
......@@ -35,7 +37,10 @@ def create_legend(g):
edge.attr["color"] = "blue"
edge.attr["style"] = "dashed"
legend.add_edge("e","f")
edge = legend.get_edge("e","f")
edge.attr["label"] = "contrastive explanation"
edge.attr["color"] = "red"
#
#==============================================================================
def visualize(dt):
......
from os import path
import base64
import dash_bootstrap_components as dbc
import numpy as np
from dash import dcc, html
import subprocess
import shlex
class NaiveBayesComponent():
    """Dash component for a Naive Bayes classifier and its explanations.

    The heavy lifting is delegated to the Perl tools shipped under
    ``pages/application/NaiveBayes/utils``.
    """

    def __init__(self, model, type_model='SKL', info=None, type_info=''):
        """Convert ``model`` with ``cnbc2xlc.pl`` and create empty widgets.

        Args:
            model: Path of the CNBC model file handed to the Perl converter.
            type_model: Currently unused.
            info: Currently unused.
            type_info: Currently unused.
        """
        # cnbc2xlc.pl refuses to run without "-f <cnbc-file>"; run() also waits
        # for the tool to finish, so no child process is left unreaped.
        conversion = subprocess.run(
            ['perl', 'pages/application/NaiveBayes/utils/cnbc2xlc.pl', '-f', model],
            stdout=subprocess.PIPE)
        print(conversion.stdout)

        self.naive_bayes = model
        self.map_file = ""
        self.network = html.Div([])
        self.explanation = []

    def update_with_explicability(self, instance, enum, xtype, solver):
        """Run the explainer on ``instance`` and rebuild ``self.explanation``.

        Args:
            instance: Instance passed to ``xpxlc.pl`` and echoed in the panel.
            enum: Currently unused.
            xtype: Currently unused.
            solver: Currently unused.

        Returns:
            A pair ``(path explanations, path contrastive explanations)``;
            each is an empty list when the explainer provided none.
        """
        # NOTE(review): xpxlc.pl is not visible here, so its expected
        # command-line options could not be checked - confirm this call.
        run = subprocess.run(
            ['perl', 'pages/application/NaiveBayes/utils/xpxlc.pl',
             self.naive_bayes, instance, self.map_file],
            stdout=subprocess.PIPE)
        print(run.stdout)

        self.explanation = []
        # NOTE(review): nothing fills this dict yet, so only the instance is
        # rendered; the tool output above still has to be parsed into it.
        explanation = {}
        self.network = html.Div([])

        # Instance summary shown on top of the explanation panel.
        self.explanation.append(html.H4("Instance : \n"))
        self.explanation.append(html.P(str([str(value) for value in instance])))

        path_keys = ("List of path explanation(s)",
                     "List of path contrastive explanation(s)")
        listed_keys = ("List of abductive explanation(s)",
                       "List of contrastive explanation(s)")
        for key, content in explanation.items():
            if key in path_keys:
                continue  # returned to the caller instead of being rendered
            if key in listed_keys:
                self.explanation.append(html.H4(key))
                for expl in content:
                    self.explanation.append(html.Hr())
                    self.explanation.append(html.P(expl))
                    self.explanation.append(html.Hr())
            else:
                self.explanation.append(html.P(key + content))

        # Both results default to an empty list, so the return statement can
        # no longer hit an unbound local or a missing key.
        list_explanations_path = explanation.get(path_keys[0], [])
        list_contrastive_explanations_path = explanation.get(path_keys[1], [])
        return list_explanations_path, list_contrastive_explanations_path
package Parsers;
use strict;
use warnings;
use Data::Dumper;
use POSIX qw( !assert );
use Exporter;
require Utils; # Must use require, to get INC updated
import Utils qw( &get_progname &get_progpath );
# Export list: every parser is exported on request only (EXPORT_OK).
BEGIN {
@Parsers::ISA = ('Exporter');
@Parsers::EXPORT_OK =
qw( &parse_xlc &parse_cnbc &parse_xmap
&parse_instance &parse_explanations
&parse_blc &parse_acc );
}
# Hint appended to every "Unable to open file" message of the parsers.
use constant F_ERR_MSG =>
"Please check file name, existence, permissions, etc.\n";
# When true, parse_xmap wraps multi-token names in single quotes and replaces
# occurrences of the feature name inside its value tokens by '??'.
use constant HLPMAP => 1;
# Character used to glue the tokens of a multi-token name together.
use constant CCAT_CH => '_';
# When true, parse_xmap asserts that categorical IDs come in order.
use constant CCHK => 0;
if (CCHK) {
## Uncomment to use assertions && debug messages
#use Carp::Assert; # Assertions are on.
}
# Parse XLC format
#
# Arguments: ($opts, $xlc, $fname); $opts is accepted but not used here.
# Reads file $fname line by line and fills hash ref $xlc with:
#   NV      number of features
#   W0      the w0 weight
#   NReal   number of real-valued features
#   RVs     array of the NReal real-valued coefficients
#   NCat    number of categorical features
#   CDs     array holding the first token of each categorical line
#   CVs<i>  array holding the remaining tokens of the i-th categorical line
# Dies when the file cannot be opened or a line has an unexpected shape.
sub parse_xlc()
{
    my ($opts, $xlc, $fname) = @_;
    # Three-argument open: $fname is always taken as a plain file name, never
    # as a mode or a pipe specification.
    open(my $fh, '<', $fname) || die "Unable to open file $fname. " . F_ERR_MSG;
    my ($nc, $nr, $rmode) = (0, 0, 0);
    # Numbers may carry an exponent (e.g. 4.5e-05): that is how Perl itself
    # prints small weights when an XLC file is written.
    my $num_re = qr/\-?\d+\.?\d*(?:[eE][\-+]?\d+)?/;
    while(<$fh>) {
        chomp;
        # NOTE(review): this only skips a line made of 'c' plus trailing
        # whitespace; parse_blc skips any line starting with 'c ' - confirm
        # which comment syntax is intended.
        next if m/^\s*c\s+$/;
        if ($rmode == 0) {              # Read number of features
            m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n";
            ($xlc->{NV}, $rmode) = ($1, 1);
        }
        elsif ($rmode == 1) {           # Read w0
            m/^\s*($num_re)\s*$/ || die "Unable to match: $_\n";
            ($xlc->{W0}, $rmode) = ($1, 2);
        }
        elsif ($rmode == 2) {           # Read number of real-valued features
            m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n";
            ($xlc->{NReal}, $rmode) = ($1, 3);
            if ($xlc->{NReal} == 0) { $rmode = 4; }   # nothing to read in mode 3
        }
        elsif ($rmode == 3) {           # Read real-valued coefficients
            m/^\s*($num_re)\s*$/ || die "Unable to match: $_\n";
            push @{$xlc->{RVs}}, $1;
            if (++$nr == $xlc->{NReal}) { ($nr, $rmode) = (0, 4); }
        }
        elsif ($rmode == 4) {           # Read number of categorical features
            m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n";
            ($xlc->{NCat}, $rmode) = ($1, 5);
        }
        elsif ($rmode == 5) {           # Read domains and weights of cat. features
            my $cvi = "CVs$nc";
            @{$xlc->{$cvi}} = split(/ +/);
            # First token is the domain size, the rest are the weights
            push @{$xlc->{CDs}}, shift @{$xlc->{$cvi}};
            if (++$nc == $xlc->{NCat}) { $rmode = 6; }
        }
        else { die "Invalid state with input: $_\n"; }
    }
    close($fh);
}
# Parse map file
#
# Arguments: ($opts, $xmap, $fname); $opts is accepted but not used here.
# Reads file $fname and fills hash ref $xmap with:
#   NC      number of classes
#   ClMap   array: class ID -> class name
#   NV      number of features
#   NReal   number of real-valued features
#   NCat    number of categorical features
#   VMap    array: feature ID -> feature name
#   CDs     hash: categorical feature index -> domain size
#   CMap    hash: categorical feature index -> array (value ID -> value name)
# Categorical features are indexed after the real-valued ones.
# Dies when the file cannot be opened or a line has an unexpected shape.
sub parse_xmap()
{
    my ($opts, $xmap, $fname) = @_;
    # Three-argument open: $fname is always taken as a plain file name
    open(my $fh, '<', $fname) || die "Unable to open file $fname. " . F_ERR_MSG;
    my ($cc, $nv, $nc, $nr, $rmode) = (0, 0, 0, 0, 0);
    while(<$fh>) {
        chomp;
        next if m/^\s*c\s+$/;
        if ($rmode == 0) {              # Read number of classes
            m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n";
            ($xmap->{NC}, $rmode, $cc) = ($1, 1, 0);
            if ($xmap->{NC} == 0) { $rmode = 2; }
        }
        elsif ($rmode == 1) {           # Read class name maps
            my @toks = split(/ +/);
            my $cid = shift @toks;
            ${$xmap->{ClMap}}[$cid] = join(CCAT_CH, @toks);
            if (++$cc == $xmap->{NC}) { $rmode = 2; }
        }
        elsif ($rmode == 2) {           # Read number of features
            m/^\s*(\d+)\s*$/ || die "Unable to match \@ $rmode: $_\n";
            ($xmap->{NV}, $rmode) = ($1, 3);
        }
        elsif ($rmode == 3) {           # Read number of real-valued features
            m/^\s*(\d+)\s*$/ || die "Unable to match \@ $rmode: $_\n";
            ($xmap->{NReal}, $rmode, $nr) = ($1, 4, 0);
            if ($xmap->{NReal} == 0) { $rmode = 5; }
        }
        elsif ($rmode == 4) {           # Read map of real-value features
            my @toks = split(/ +/);
            my $rid = shift @toks;
            ${$xmap->{VMap}}[$rid] = join(CCAT_CH, @toks);
            if (++$nr == $xmap->{NReal}) { $rmode = 5; }
        }
        elsif ($rmode == 5) {           # Read number of categorical features
            m/^\s*(\d+)\s*$/ || die "Unable to match \@ $rmode: $_\n";
            # Categorical indices continue where the real-valued ones stopped
            ($xmap->{NCat}, $rmode, $nc) = ($1, 6, $nr);
        }
        elsif ($rmode == 6) {           # Read categorical feature
            my @toks = split(/ +/);
            my $cid = shift @toks;
            if (!HLPMAP) {
                ${$xmap->{VMap}}[$cid] = join(CCAT_CH, @toks); }
            else {
                # Multi-token names are kept readable: quoted, space-separated
                my ($sch, $ech, $jch) = ('', '', '');
                if ($#toks > 0) { ($sch, $ech, $jch) = ('\'', '\'', ' '); }
                ${$xmap->{VMap}}[$cid] = $sch . join($jch, @toks) . $ech;
            }
            $rmode = 7;
            if (CCHK) { assert($cid == $nc, "Invalid categorical ID"); }
        }
        elsif ($rmode == 7) {           # Read domain size of current feature
            m/^\s*(\d+)\s*$/ || die "Unable to match \@ $rmode: $_\n";
            ($xmap->{CDs}->{$nc}, $rmode, $nv) = ($1, 8, 0);
        }
        elsif ($rmode == 8) {           # Read values of categorical feature
            my @toks = split(/ +/);
            my $vid = shift @toks;
            if (!HLPMAP) {
                ${$xmap->{CMap}->{$nc}}[$vid] = join(CCAT_CH, @toks); }
            else {
                my ($repl, $sch, $ech, $jch) = (0, '', '', '');
                # The feature name is matched literally: names containing
                # regex metacharacters ('+', '(', '?', ...) must neither
                # change the match nor make the pattern invalid.
                my $vname_re = quotemeta(${$xmap->{VMap}}[$nc]);
                for (my $i=0; $i<=$#toks; ++$i) {
                    if ($toks[$i] =~ m/$vname_re/) {
                        $toks[$i] =~ s/$vname_re/\?\?/g;
                        $repl = 1;
                    }
                }
                if ($#toks > 0 && !$repl) { ($sch,$ech,$jch)=('\'','\'',' '); }
                ${$xmap->{CMap}->{$nc}}[$vid] = $sch . join($jch, @toks) . $ech;
            }
            if (++$nv == $xmap->{CDs}->{$nc}) {
                # Domain complete: move to the next feature, or stop after the last
                if (++$nc == $xmap->{NReal}+$xmap->{NCat}) { $rmode = 9; }
                else { $rmode = 6; }
            }
        }
        else { die "Invalid state with input \@ $rmode: $_\n"; }
    }
    close($fh);
}
# Parse CNBC format -- currently hard-coded for 2 classes
#
# Arguments: ($opts, $cnbc, $fname); $opts is accepted but not used here.
# Reads file $fname and fills hash ref $cnbc with:
#   NC       number of classes
#   Prior    array of the NC class priors
#   NV       number of features
#   CPT<j>   hash for feature j: D (domain size) and, per class i, an array
#            C<i> with the D space-separated values of that class' line
# Dies when the file cannot be opened or a line has an unexpected shape.
sub parse_cnbc()
{
    my ($opts, $cnbc, $fname) = @_;
    # Three-argument open: $fname is always taken as a plain file name
    open(my $fh, '<', $fname) || die "Unable to open file $fname. " . F_ERR_MSG;
    my ($cc, $cv, $pol, $rmode) = (0, 0, 0, 0);
    while(<$fh>) {
        chomp;
        if ($rmode == 0) {              # Read number of classes
            m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n";
            ($cnbc->{NC}, $rmode, $cc) = ($1, 1, 0);
        }
        elsif ($rmode == 1) {           # Read priors
            m/^\s*(\-?\d+\.?\d*)\s*$/ || die "Unable to match: $_\n";
            push @{$cnbc->{Prior}}, $1;
            if (++$cc == $cnbc->{NC}) { $rmode = 2; }
        }
        elsif ($rmode == 2) {           # Read number of features
            m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n";
            ($cnbc->{NV}, $cv, $rmode) = ($1, 0, 3);
        }
        elsif ($rmode == 3) {           # Read domain size of feature
            my $cpt = "CPT$cv";
            if ($cv == $cnbc->{NV}) { die "Too many features specified?\n"; }
            m/^\s*(\d+)\s*$/ || die "Unable to match: $_\n";
            ($cnbc->{$cpt}->{D}, $cc, $rmode) = ($1, 0, 4);
        }
        elsif ($rmode == 4) {           # Read CPT for feature
            my $cpt = "CPT$cv";
            my $ccl = "C$cc";
            my @probs = split(/ +/);
            # One value per domain element is expected on each class line
            if ($#probs+1 != $cnbc->{$cpt}->{D}) { die "Invalid CPT def\n"; }
            for (my $i=0; $i<=$#probs; ++$i) {
                $probs[$i] =~ m/(\-?\d+\.?\d*)/ || die "Unable to match: $_\n";
                push @{$cnbc->{$cpt}->{$ccl}}, $probs[$i];
            }
            if (++$cc == $cnbc->{NC}) {
                ($cv, $cc, $rmode) = ($cv+1, 0, 3);   # Move to next feature
            }
        } else { die "Unexpected read mode in CNBC file\n"; }
    }
    close($fh);
}
# Parse BLC format
#
# Arguments: ($opts, $blc, $fname); $opts is accepted but not used here.
# Blank lines and lines starting with 'c ' are skipped.  The first remaining
# line gives NV; every following line is one weight, stored in order in the
# array Ws.  At most NV+1 weights are accepted.
# Dies when the file cannot be opened or a line has an unexpected shape.
sub parse_blc()
{
    my ($opts, $blc, $fname) = @_;
    # Three-argument open: $fname is always taken as a plain file name
    open(my $fh, '<', $fname) || die "Unable to open file $fname. " . F_ERR_MSG;
    my ($rmode, $cnt) = (0, 0);
    while(<$fh>) {
        next if m/^\s*$/ || m/^c\s+/;
        chomp;
        if ($rmode == 0) {              # Read number of features
            m/\s*(\d+)\s*$/ || die "Unable to match: $_\n";
            ($blc->{NV}, $rmode) = ($1, 1);
        }
        elsif ($rmode == 1) {           # Read the weights, one per line
            if ($cnt == $blc->{NV}+1) {
                die "Too many lines in BLC description??\n"; }
            m/^\s*(\-?\d+\.?\d*)\s*$/ || die "Unable to match: $_\n";
            ${$blc->{Ws}}[$cnt++] = $1;
        }
    }
    close($fh);
}
# Parse ACC format
#
# Arguments: ($opts, $acc, $fname); $opts is accepted but not used here.
# Blank lines and lines starting with 'c ' are skipped.  Fills hash ref $acc:
#   NC                  number of classes
#   NV                  number of features
#   VV->{C<i>}->{W0}    first weight of class i
#   VV->{C<i>}->{P0|P1} arrays indexed by feature; the file alternates one P0
#                       line and one P1 line per feature
# Reading stops once NC classes have been read.
# Dies when the file cannot be opened or a line has an unexpected shape.
sub parse_acc()
{
    my ($opts, $acc, $fname) = @_;
    # Three-argument open: $fname is always taken as a plain file name
    open(my $fh, '<', $fname) || die "Unable to open file $fname. " . F_ERR_MSG;
    my ($cc, $cv, $pol, $rmode) = (0, 0, 0, 0);
    while(<$fh>) {
        next if m/^\s*$/ || m/^c\s+/;
        chomp;
        if ($rmode == 0) {              # Read number of classes
            # \d+ rather than \d: a single digit would silently keep only
            # the last digit of a class count of 10 or more.
            m/\s*(\d+)\s*$/ || die "Unable to match: $_\n";
            ($acc->{NC}, $rmode) = ($1, 1);
        }
        elsif ($rmode == 1) {           # Read number of features
            m/\s*(\d+)\s*$/ || die "Unable to match: $_\n";
            ($acc->{NV}, $rmode) = ($1, 2);
        }
        elsif ($rmode == 2) {           # Read W0 of the current class
            my $class = "C$cc";
            m/^\s*(\-?\d+\.?\d*)\s*$/ || die "Unable to match: $_\n";
            $acc->{VV}->{$class}->{W0} = $1;
            $rmode = 3;
        }
        elsif ($rmode == 3) {           # Read the P0/P1 weights of the class
            my $class = "C$cc";
            my $polarity = "P$pol";
            m/^\s*(\-?\d+\.?\d*)\s*$/ || die "Unable to match: $_\n";
            ${$acc->{VV}->{$class}->{$polarity}}[$cv] = $1;
            $pol = 1 - $pol;
            if ($pol == 0) { $cv++; }   # both polarities read: next feature
            if ($cv == $acc->{NV}) {
                ($cc, $cv, $pol) = ($cc+1, 0, 0);
                if ($cc == $acc->{NC}) { last; }
                $rmode = 2;
            }
        }
    }
    close($fh);
}
# Parse instance format
#
# Arguments: ($opts, $inst, $fname); $opts is accepted but not used here.
# Blank lines and lines starting with 'c ' are skipped.  Fills hash ref $inst:
#   NV  number of features (first line)
#   E   array with the NV feature values (one per line)
#   C   value of the last line read after the feature values
# Only the trailing run of digits of each line is captured.
# Dies when the file cannot be opened or a line does not end with digits.
sub parse_instance()
{
    my ($opts, $inst, $fname) = @_;
    # Three-argument open: $fname is always taken as a plain file name
    open(my $fh, '<', $fname) || die "Unable to open file $fname. " . F_ERR_MSG;
    my ($cnt, $rmode) = (0, 0);
    while(<$fh>) {
        next if m/^\s*$/ || m/^c\s+/;
        chomp;
        if ($rmode == 0) {              # Read number of features
            m/\s*(\d+)\s*$/ || die "Unable to match: $_\n";
            ($inst->{NV}, $rmode) = ($1, 1);
        }
        elsif ($rmode == 1) {           # Read feature values
            m/\s*(\d+)\s*$/ || die "Unable to match: $_\n";
            ${$inst->{E}}[$cnt++] = $1;
            if ($cnt == $inst->{NV}) { $rmode = 2; }
        }
        elsif ($rmode == 2) {           # Read class
            m/\s*(\d+)\s*$/ || die "Unable to match: $_\n";
            $inst->{C} = $1;
        }
    }
    close($fh);
}
# Parse explanations
#
# Arguments: ($fname, $xpl) -- note the file name comes first here.
# Blank lines and lines starting with 'c ' are skipped.  Every other line is
# split on spaces, its first token is dropped, and the remaining tokens are
# appended, as one array reference, to the array $xpl->{Expl}.
# Dies when the file cannot be opened.
sub parse_explanations()
{
    my ($fname, $xpl) = @_;
    # Three-argument open: $fname is always taken as a plain file name
    open(my $fh, '<', $fname) || die "Unable to open file $fname. " . F_ERR_MSG;
    while(<$fh>) {
        next if m/^\s*$/ || m/^c\s+/;
        chomp;
        my @lits = split(/ +/);
        shift @lits;   # Drop 'Expl: '
        push @{$xpl->{Expl}}, \@lits;
    }
    close($fh);
}
END {
}
1; # to ensure that the 'require' or 'use' succeeds
package Utils;
use strict;
use warnings;
use Data::Dumper;
use POSIX;
use Exporter();
use Sys::Hostname;
BEGIN {
@Utils::ISA = ('Exporter');
@Utils::EXPORT_OK = qw( &get_progname &get_progpath &round &SIG_handler );
}
#------------------------------------------------------------------------------#
# Execution path handling
#------------------------------------------------------------------------------#
# Return the file name of the running script: the last '/'-separated
# component of $0.
sub get_progname() {
    my @path_parts = split(m{/}, $0);
    return $path_parts[-1];
}
# Return the directory part of the running script's path ($0), without a
# trailing '/'.  A script invoked by bare name yields './'; a script located
# directly under the root directory yields '/'.
sub get_progpath() {
    my @progname_toks = split(/\//, $0);
    pop @progname_toks;                      # drop the script name itself
    my $progpath = join('/', @progname_toks);
    if ($progpath eq '') {
        # In a single-quoted string '\.\/' keeps its backslashes and is not a
        # usable path, so the plain directory is returned instead.
        $progpath = ($0 =~ m{^/}) ? '/' : './';
    }
    return $progpath;
}
# Return the short host name: everything before the first '.' of the name
# reported by Sys::Hostname.
sub get_hostname() {
    my $full_host_name = &Sys::Hostname::hostname();
    # [^.]+ instead of \w+ so that hyphenated names ("web-01.example.org")
    # are not cut at the hyphen; fall back to the full name when nothing
    # matches rather than returning a stale $1.
    if ($full_host_name =~ m/^([^.]+)/) { return $1; }
    return $full_host_name;
}
# @INC hook (kept here as a template; each script needs its own copy).
#
# Arguments: ($cref, $pmname) -- the hook itself and the module file name.
# Looks for $pmname in the directory of the running script ($0) and returns
# an open read handle on it.  Dies if the file cannot be opened.
sub resolve_inc() { # Kept here as a template; need a copy in each script...
    my ($cref, $pmname) = @_;
    my @progname_toks = split(/\//, $0);
    pop @progname_toks;
    my $progpath = join('/', @progname_toks);
    # A script started by bare name has no directory part: look in the
    # current directory ('/' when $0 sits directly under the root) instead
    # of building a path under '/'.
    if ($progpath eq '') { $progpath = ($0 =~ m{^/}) ? '' : '.'; }
    my $fullname = $progpath . '/' . $pmname;
    my $fh;
    # Three-argument open: the name is never interpreted as a mode or pipe
    open($fh, '<', $fullname) || die "non-existing file: $pmname\n";
    return $fh;
}
#------------------------------------------------------------------------------#
# Signal handling utilities
#------------------------------------------------------------------------------#
# Install the package's signal handlers: INT and TERM go straight to
# INT_handler, the remaining signals are routed through SIG_handler.
sub register_handlers()
{
    foreach my $signame (qw(INT TERM)) {
        $SIG{$signame} = 'Utils::INT_handler';
    }
    foreach my $signame (qw(ABRT SEGV BUS QUIT XCPU)) {
        $SIG{$signame} = 'Utils::SIG_handler';
    }
}
# Values collected by push_arg; INT_handler passes a reference to this array
# to every registered callback.
my @args = ();
# Code references collected by push_callback; INT_handler calls each of them
# before exiting.
my @callback = ();
# Append the first argument to the list later handed to the callbacks.
sub push_arg()
{
    my ($new_arg) = @_;
    push(@args, $new_arg);
}
# Append the first argument (a code reference) to the list of callbacks run
# by INT_handler.
sub push_callback()
{
    my ($new_callback) = @_;
    push(@callback, $new_callback);
}
# Handler installed by register_handlers for ABRT, SEGV, BUS, QUIT and XCPU;
# it only delegates to INT_handler.
sub SIG_handler()
{
&Utils::INT_handler();
}
# Signal handler: announce the signal, run every registered callback (e.g. to
# print stats or summaries) with a reference to the collected arguments, and
# terminate the process.
sub INT_handler()
{
    print "\nReceived system signal. Cleaning up & terminating...\n";
    for my $handler (@callback) {
        $handler->(\@args);
    }
    # Exit status 20 denotes a "resources exceeded" condition
    exit 20;
}
#------------------------------------------------------------------------------#
# Useful utils
#------------------------------------------------------------------------------#
# Round $rval to the nearest integer, halves away from zero
# (2.5 -> 3, -2.5 -> -3).
sub round() {
    my ($rval) = @_;
    # int() truncates towards zero, so int($rval + 0.5) alone is wrong for
    # negative values (-2.6 would give -2); mirror the computation instead.
    return ($rval < 0) ? -int(-$rval + 0.5) : int($rval + 0.5);
}
END {
}
1; # to ensure that the 'require' or 'use' succeeds
package Writers;
use strict;
use warnings;
use Data::Dumper;
use POSIX;
use Exporter;
require Utils; # Must use require, to get INC updated
import Utils qw( &get_progname &get_progpath );
BEGIN {
@Writers::ISA = ('Exporter');
@Writers::EXPORT_OK = qw( &write_xlc );
}
# Export XLC format
#
# Arguments: ($opts, $xlc); $opts is accepted but not used here.
# Prints to the currently selected handle, one item per line: NV, W0, NReal,
# the NReal entries of RVs, NCat, and then for each categorical feature i a
# line made of CDs[i] followed by the space-separated entries of CVs<i>.
sub write_xlc()
{
    my ($opts, $xlc) = @_;
    print $xlc->{NV} . "\n";
    print $xlc->{W0} . "\n";
    print $xlc->{NReal} . "\n";
    foreach my $r (0 .. $xlc->{NReal} - 1) {
        print $xlc->{RVs}->[$r] . "\n";
    }
    print $xlc->{NCat} . "\n";
    foreach my $c (0 .. $xlc->{NCat} - 1) {
        my $weights_key = "CVs$c";
        my $line = "$xlc->{CDs}->[$c] @{$xlc->{$weights_key}}\n";
        print $line;
    }
}
END {
}
1; # to ensure that the 'require' or 'use' succeeds
#!/usr/bin/env perl
## Tool for translating the probabilities of an CNBC into a
## sequence of non-negative weights which are then represented
## in the XLC format.
## Script specifically assumes *2* classes
push @INC, \&resolve_inc;
use strict;
use warnings;
use Data::Dumper;
use Getopt::Std;
require Parsers;
import Parsers qw( parse_cnbc );
require Writers;
import Writers qw( write_xlc );
use constant DBG => 0;   ## Also, comment out unused 'uses'
use constant CHK => 0;

my $f_err_msg = "Please check file name, existence, permissions, etc.\n";

# 0. Read command line arguments
my %opts = ();
&read_opts(\%opts);

if ((CHK || DBG) && (defined($opts{k}) || defined($opts{d}))) {
    ## Uncomment to use assertions && debug messages
    #use Carp::Assert; # Assertions are on.
    #if (DBG && $opts{d}) {
    #    use Data::Dumper;
    #}
}
if (defined($opts{o})) {
    # Send everything printed from here on to the requested file; stop right
    # away if it cannot be created instead of silently losing the output.
    open($opts{FH}, '>', $opts{o})
        || die "Unable to open file $opts{o}. " . $f_err_msg;
    select($opts{FH});
}

# 1. Data structures
my %cnbc = ();
my %xlc = ();
my $mval = 0;   # set by process_weights
my $tval = 0;   # set by process_weights

# 2. Read ML model (definition of (C)NBC in CNBC format)
&parse_cnbc(\%opts, \%cnbc, $opts{f});
if ($opts{d}) { warn Data::Dumper->Dump([ \%cnbc ], [ qw(cnbc) ]); }

# 3. Translate CNBC probabilities into additive, non-negative weights
&process_weights(\%opts, \%cnbc);
if ($opts{d}) { warn Data::Dumper->Dump([ \%cnbc ], [ qw(cnbc) ]); }

# 4. Reduce CNBC (w/ weights) into XLC
&reduce_cnbc_xlc(\%opts, \%cnbc, \%xlc);
if ($opts{d}) { warn Data::Dumper->Dump([ \%xlc ], [ qw(xlc) ]); }

# 5. Print ML model in XLC format
&write_xlc(\%opts, \%xlc);

1;
# Core functions

# Goal is to apply a translation to the prob values
#
# Arguments: ($opts, $cnbc).  Every prior and every CPT entry of $cnbc is
# replaced in place by a non-negative additive weight:
#   - a non-zero probability p becomes log(p) + T
#   - a zero probability becomes M + T
# where M = (sum of all logs) - 1 and T = -M if some probability is zero,
# otherwise T = -(smallest log).  The file-level $mval and $tval receive M
# and T.
sub process_weights()
{
    my ($opts, $cnbc) = @_;
    if (CHK && $opts->{k}) {
        assert($cnbc->{NC}==2, "Cannot handle $cnbc->{NC} classes\n");
    }
    # Collect a reference to every probability slot (priors first, then the
    # CPT entries feature by feature and class by class) so that both passes
    # below visit the same cells in the same order.
    my @slots = ();
    for (my $i=0; $i<=$#{$cnbc->{Prior}}; ++$i) {
        push @slots, \$cnbc->{Prior}->[$i];
    }
    for (my $j=0; $j<$cnbc->{NV}; ++$j) {
        my $cpt = "CPT$j";
        for (my $i=0; $i<=$#{$cnbc->{Prior}}; ++$i) {
            my $ccl = "C$i";
            for (my $k=0; $k<$cnbc->{$cpt}->{D}; ++$k) {
                push @slots, \$cnbc->{$cpt}->{$ccl}->[$k];
            }
        }
    }

    # 1. First traversal: compute & sum logarithms and flag 0 probs
    my ($hasp0, $sumlogs, $minv) = (0, 0, 0);
    my @zero_slots = ();
    foreach my $slot (@slots) {
        if ($$slot == 0) {
            # Remember the cell itself: after this pass a value of 0 can also
            # be log(1), which must NOT be mistaken for a zero probability.
            $hasp0 = 1;
            push @zero_slots, $slot;
            next;
        }
        my $logv = log($$slot);
        $sumlogs += $logv;
        $$slot = $logv;
        if ($logv < $minv) { $minv = $logv; }
    }
    $mval = $sumlogs - 1;
    $tval = ($hasp0) ? -$mval : -$minv;

    # 2. Second traversal: update 0 probs, offset weights by T
    foreach my $slot (@zero_slots) { $$slot = $mval; }
    foreach my $slot (@slots)      { $$slot += $tval; }

    if ($opts->{d}) { warn Data::Dumper->Dump([ $cnbc ], [ qw(cnbc_pw) ]); }
}
# Reduce a two-class CNBC (already holding additive weights) to an XLC.
#
# Arguments: ($opts, $cnbc, $xlc); $opts is accepted but not used here.
# Every XLC weight is the class-0 value minus the class-1 value: W0 from the
# priors and, for feature j, one entry of CVs<j> per domain value.  All
# features are declared categorical (NReal is 0); CDs receives their domain
# sizes.
sub reduce_cnbc_xlc()
{
    my ($opts, $cnbc, $xlc) = @_;
    $xlc->{NV}    = $cnbc->{NV};
    $xlc->{W0}    = $cnbc->{Prior}->[0] - $cnbc->{Prior}->[1];
    $xlc->{NReal} = 0;
    $xlc->{NCat}  = $cnbc->{NV};
    foreach my $feat (0 .. $cnbc->{NV} - 1) {
        my ($cpt_key, $cv_key) = ("CPT$feat", "CVs$feat");
        my $dom_size = $cnbc->{$cpt_key}->{D};
        push @{$xlc->{CDs}}, $dom_size;
        foreach my $val (0 .. $dom_size - 1) {
            push @{$xlc->{$cv_key}},
                $cnbc->{$cpt_key}->{C0}->[$val] - $cnbc->{$cpt_key}->{C1}->[$val];
        }
    }
}
# Format parsing functions

# Retired entry point: ACC files must be read with the shared parser of the
# Parsers package (parse_acc).  The function always dies; the private copy of
# the parsing loop that used to follow the 'die' could never run and has been
# removed.
sub read_acc_spec()
{
    my ($fname, $acc) = @_;
    die "Must use common parser!!!!\n";
}
# Utilities

# Parse the command line (@ARGV) into hash ref $opts.
# Flags: -h -d -v -k; options with a value: -f <cnbc-file>, -o <out-file>.
# With -h the help text is printed through prt_help; otherwise -f is
# mandatory and its absence ends the program with a usage message.
sub read_opts()
{
    my ($opts) = @_;
    getopts("hdvkf:o:", $opts);
    return &prt_help() if $opts->{h};
    die "Usage: $0 [-h] [-d] [-v] [-k] [-o <out-file>] -f <cnbc-file>\n"
        unless defined($opts->{f});
}
# Print the help text (tool name taken from $0 via toolname) and exit.
sub prt_help()
{
my $tname = &toolname($0);
print <<"EOF";
$tname: Translate CNBC format into XLC format
Usage: $tname [-h] [-d] [-v] [-k] [-o <out-file>] -f <cnbc-file>
-f <cnbc-file> specification of CNBC file
-o <out-file> output file for exporting XLC format
-k perform consistency checks & exit if error
-v verbose mode
-d debug mode
-h prints this help
Author: joao.marques-silva\@univ-toulouse.fr
EOF
exit();
}
# Return the trailing run of name characters (letters, digits, '.', '_',
# '-') of $tname, i.e. the file name part of a script path.  When $tname does
# not end with such a character it is returned unchanged.
sub toolname()
{
    my ($tname) = @_;
    # Only trust $1 after a successful match: otherwise it still holds the
    # capture of some earlier, unrelated match.
    if ($tname =~ m/([\.\_\-a-zA-Z0-9]+)$/) { return $1; }
    return $tname;
}
#------------------------------------------------------------------------------#
# Auxiliary functions
#------------------------------------------------------------------------------#
# @INC hook (copy from the template kept in the Utils package).
#
# Arguments: ($cref, $pmname) -- the hook itself and the module file name.
# Looks for $pmname in the directory of the running script ($0) and returns
# an open read handle on it.  Dies if the file cannot be opened.
sub resolve_inc() { # Copy from template kept in UTILS package
    my ($cref, $pmname) = @_;
    my @progname_toks = split(/\//, $0);
    pop @progname_toks;
    my $progpath = join('/', @progname_toks);
    # A script started by bare name has no directory part: look in the
    # current directory ('/' when $0 sits directly under the root) instead
    # of building a path under '/'.
    if ($progpath eq '') { $progpath = ($0 =~ m{^/}) ? '' : '.'; }
    my $fullname = $progpath . '/' . $pmname;
    # Three-argument open: the name is never interpreted as a mode or pipe
    open(my $fh, '<', $fullname) || die "non-existing file: $pmname\n";
    return $fh;
}
# jpms
import argparse
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import CategoricalNB
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
import pickle
import os
import numpy as np
from scipy.special import logsumexp
def predict_proba(X, clf, precision=None):
    """Predict a class index for every row of ``X``.

    Despite its name, the function returns the index of the most probable
    class of each row, not the probabilities themselves.

    Args:
        X: Object with ``shape`` and ``values`` whose cell ``[r, i]`` is the
            integer category index of feature ``i`` for row ``r``.
        clf: Classifier exposing ``feature_log_prob_`` (one
            ``(n_classes, n_categories)`` array per feature) and
            ``class_log_prior_``.
        precision: When given, every probability is rounded to that many
            decimals (and clipped to at least 1e-12) before use.

    Returns:
        1-D array holding, per row, the position of the winning class.
    """
    def _rounded_log(log_probs):
        # Round in probability space; clip so that log() never sees 0.
        probs = np.round(np.exp(log_probs), precision)
        return np.log(np.clip(probs, 1e-12, None))

    if precision is None:
        feature_priors = clf.feature_log_prob_
        class_priors = np.asarray(clf.class_log_prior_)
    else:
        feature_priors = [_rounded_log(table) for table in clf.feature_log_prob_]
        class_priors = _rounded_log(np.asarray(clf.class_log_prior_))

    # One column per class of the classifier (no longer fixed to two).
    jll = np.zeros((X.shape[0], class_priors.shape[0]))
    categories = X.values
    for i in range(X.shape[1]):
        jll += feature_priors[i][:, categories[:, i]].T
    total_ll = jll + class_priors

    log_prob_x = logsumexp(total_ll, axis=1)
    return np.argmax(np.exp(total_ll - np.atleast_2d(log_prob_x).T), axis=1)
if __name__ == "__main__":
parser = argparse.ArgumentParser('Categorical NBC generator.')
parser.add_argument('-d', type=str, help="dataset path")
parser.add_argument('-op', type=str, help="output pickle classifier path", default="")
parser.add_argument('-oc', type=str, help="output NBC classifier path", default="")
parser.add_argument('-oi', type=str, help="output inst path", default="")
parser.add_argument('-ox', type=str, help="output xmap path", default="")
parser.add_argument('-v', type=int, help="verbose", default=0)
parser.add_argument('-p', type=int, help="precision of classifier", default=None)
args = parser.parse_args()
df = pd.read_csv(args.d)
df.columns = [s.strip() for s in df.columns.values]
encoders = dict()
min_categories = dict()
for column in df.columns:
if df[column].apply(type).eq(str).all():
df[column] = df[column].str.strip()
enc = LabelEncoder()
enc.fit(df[column])
df[column] = enc.transform(df[column])
min_categories[column] = len(enc.classes_)
encoders[column] = enc
X = df.drop(df.columns[-1], axis=1)
y = df[df.columns[-1]]
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, random_state=0)
clf = CategoricalNB(min_categories=np.array(list(min_categories.values())).astype(int)[:-1])
clf.fit(X_train, y_train)
if args.v:
print("----------------------")
print("Initial accuracy:")
print("Train accuracy: ", accuracy_score(clf.predict(X_train), y_train))
print("Test accuracy: ", accuracy_score(clf.predict(X_test), y_test))
print("----------------------")
if args.p is not None:
print("----------------------")
print("Rounded accuracy (precision=" + str(args.p) + "):")
print("Train accuracy: ", accuracy_score(predict_proba(X_train, clf, args.p), y_train))
print("Test accuracy: ", accuracy_score(predict_proba(X_test, clf, args.p), y_test))
print("----------------------")
if args.ox:
if not os.path.exists(os.path.dirname(args.ox)):
os.makedirs(os.path.dirname(args.ox))
with open(args.ox, "w") as f:
# --------- Target -----------
enc = encoders[y.name]
C = len(enc.classes_)
f.write(str(C) + "\n")
for category, target in zip(enc.classes_, enc.transform(enc.classes_)):
f.write(str(target) + " " + str(category) + "\n")
# --------- Features ---------
n = X.shape[1]
f.write(str(n) + "\n")
f.write("0" + "\n")
f.write(str(n) + "\n")
for i, feature in enumerate(X.columns):
f.write(str(i) + " " + str(feature) + "\n")
enc = encoders[feature]
f.write(str(len(enc.classes_)) + "\n")
for category, label in zip(enc.classes_, enc.transform(enc.classes_)):
f.write(str(label) + " " + str(category) + "\n")
"""
FUTURE DEVELOPMENT
# Get types of features (categorical or continuous (=real-valued))
dtypes = dict()
for column in X.columns:
if len(X[column].unique()) < (X.shape[0] / 3):
dtypes[column] = "categorical"
else:
dtypes[column] = "continuous"
# Real-valued features
f.write(str(len(dict((k, v) for k, v in dtypes.items() if v == "continuous"))) + "\n")
for i, (feature, dtype) in enumerate(dtypes.items()):
if dtype == "continuous":
f.write(str(i) + " " + str(feature) + "\n")
enc = encoders[feature]
f.write(str(len(enc.classes_)) + "\n")
for category, label in zip(enc.classes_, enc.transform(enc.classes_)):
f.write(str(label) + " " + str(category) + "\n")
# Categorical features
f.write(str(len(dict((k, v) for k, v in dtypes.items() if v == "categorical"))) + "\n")
for i, (feature, dtype) in enumerate(dtypes.items()):
if dtype == "categorical":
f.write(str(i) + " " + str(feature) + "\n")
enc = encoders[feature]
f.write(str(len(enc.classes_)) + "\n")
for category, label in zip(enc.classes_, enc.transform(enc.classes_)):
f.write(str(label) + " " + str(category) + "\n")
"""
if args.op:
if not os.path.exists(os.path.dirname(args.op)):
os.makedirs(os.path.dirname(args.op))
pickle.dump(clf, open(args.op, "wb"))
if args.oc:
if not os.path.exists(os.path.dirname(args.oc)):
os.makedirs(os.path.dirname(args.oc))
with open(args.oc, "w") as f:
n = len(clf.classes_)
f.write(str(n) + "\n")
class_priors = np.exp(clf.class_log_prior_)
for i in class_priors:
if args.p is not None:
f.write(str(np.round(np.format_float_positional(i, trim='-'), args.p)) + "\n")
else:
f.write(str(np.format_float_positional(i, trim='-')) + "\n")
m = X.shape[1]
f.write(str(m) + "\n")
feature_log_priors = clf.feature_log_prob_
for feature_log_prior in feature_log_priors:
feature_prior = np.exp(feature_log_prior)
f.write(str(feature_prior.shape[1]) + "\n")
for feature_class_prior in feature_prior:
for v in feature_class_prior:
if args.p is not None:
f.write(str(np.round(np.format_float_positional(v, trim='-'), args.p)) + " ")
else:
f.write(str(np.format_float_positional(v, trim='-')) + " ")
f.write("\n")
if args.oi:
if not os.path.exists(os.path.dirname(args.oi)):
os.makedirs(os.path.dirname(args.oi))
name = next(s for s in reversed(args.oi.split("/")) if s)
for i, (_, sample) in enumerate(X.iterrows()):
path = os.path.join(args.oi, name + "." + str(i+1) + ".txt")
with open(path, "w") as f:
f.write(str(len(sample)) + "\n")
for value in sample:
f.write(str(value) + "\n")
f.write(str(clf.predict([sample])[0]) + "\n")
print "Called::\n";
my $f_err_msg = "Please check file name, existence, permissions, etc.\n";
This diff is collapsed.
......@@ -3,6 +3,8 @@ import dash_bootstrap_components as dbc
import dash_daq as daq
from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent
from pages.application.NaiveBayes.NaiveBayesComponent import NaiveBayesComponent
import subprocess
class Application():
def __init__(self, view):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment