diff --git a/.gitignore b/.gitignore index ed1537f039bd0b1e28e3d3fd73f3cae685e3e067..f7db4b641de813a678b90840deee40b361228a41 100644 --- a/.gitignore +++ b/.gitignore @@ -3,9 +3,4 @@ __pycache__ pages/application/DecisionTree/utils/__pycache__ pages/application/DecisionTree/__pycache__ pages/application/__pycache__ -decision_tree_classifier_20170212.pkl -push_command -adult.pkl -adult_data_00000.inst -iris_00000.txt -tests \ No newline at end of file +tests/push_command \ No newline at end of file diff --git a/README.md b/README.md index 9f48666b427cdfa4ea4f9d3f92eb9c512a386643..73cd73236adf1dd8aff083939e563d94c7fe862d 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,18 @@ # Pnria Projet Deeplever -## Requirements +## How to use it + +Set the parameters and upload the models. +The instance should either be a .txt at format (feature1=...,feature2=...) or a json file + +You will find possible models in the tests file + +## Trying the app + +### Requirements `pip install -r requirements.txt` -## Running +### Running Run app.py then visit localhost diff --git a/assets/header.css b/assets/header.css index b7dcabbc23dda219222a7fe465f8efa167df437e..3a882eae1b7fdffd40dd37f87941a5508a4f29d1 100644 --- a/assets/header.css +++ b/assets/header.css @@ -36,6 +36,12 @@ div.sidebar.col-3 { background-color:gray; } +.sidebar .check-boxes{ + width: 100%; + height: 40px; + text-align: center; +} + .sidebar .upload { width: 100%; height: 50px; @@ -49,7 +55,6 @@ div.sidebar.col-3 { .sidebar .Select-control { width: 100%; - height: 30px; line-height: 30px; border-width: 1px; border-radius: 5px; @@ -60,7 +65,7 @@ div.sidebar.col-3 { background-color: rgb(26,26,26); } -.sidebar .sidebar-dropdown{ +.sidebar .dropdown{ width: 100%; height: 30px; line-height: 30px; diff --git a/callbacks.py b/callbacks.py index 53779188ee9bd9b4de03049a0058529ba89d4787..5ef62d601526f58dbf4d39f45ca698ee867d80ad 100644 --- a/callbacks.py +++ b/callbacks.py @@ -4,7 +4,7 @@ from dash import Input, Output, State from dash.dependencies import Input, Output, State from dash.exceptions import PreventUpdate -from utils import parse_contents_graph, parse_contents_instance +from utils import parse_contents_graph, parse_contents_instance, parse_contents_data def register_callbacks(page_home, page_course, page_application, app): @@ -31,12 +31,15 @@ def register_callbacks(page_home, page_course, page_application, app): @app.callback( Output('pretrained_model_filename', 'children'), + Output('info_filename', 'children'), Output('instance_filename', 'children'), Output('graph', 'children'), Output('explanation', 'children'), Input('ml_model_choice', 'value'), Input('ml_pretrained_model_choice', 'contents'), State('ml_pretrained_model_choice', 'filename'), + Input('model_info_choice', 'contents'), + State('model_info_choice', 'filename'), Input('ml_instance_choice', 'contents'), State('ml_instance_choice', 'filename'), Input('number_explanations', 'value'), @@ -45,55 +48,73 @@ def register_callbacks(page_home, page_course, page_application, app): Input('expl_choice', 'value'), prevent_initial_call=True ) - def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, instance_contents, instance_filename, enum, xtype, solver, expl_choice): + def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, model_info, model_info_filename, \ + instance_contents, instance_filename, enum, xtype, solver, expl_choice): ctx = dash.callback_context if ctx.triggered: ihm_id = ctx.triggered[0]['prop_id'].split('.')[0] model_application = page_application.model + + # Choice of model if ihm_id == 'ml_model_choice' : - model_application.update_ml_model(value_ml_model) - return None, None, None, None + model_application.update_ml_model(value_ml_model) + return None, None, None, None, None + # Choice of pkl pretrained model elif ihm_id == 'ml_pretrained_model_choice': - if value_ml_model is None : + if model_application.ml_model is None : + raise PreventUpdate + graph = parse_contents_graph(pretrained_model_contents, pretrained_model_filename) + model_application.update_pretrained_model(graph) + if not model_application.add_info : + model_application.update_pretrained_model_layout() + return pretrained_model_filename, None, None, model_application.component.network, None + else : + return pretrained_model_filename, None, None, None, None + + # Choice of information for the model + elif ihm_id == 'model_info_choice': + if model_application.ml_model is None : raise PreventUpdate - tree, typ = parse_contents_graph(pretrained_model_contents, pretrained_model_filename) - model_application.update_pretrained_model(tree, typ) - return pretrained_model_filename, None, model_application.component.network, None + model_info = parse_contents_data(model_info, model_info_filename) + model_application.update_pretrained_model_layout_with_info(model_info, model_info_filename) + return pretrained_model_filename, model_info_filename, None, model_application.component.network, None + # Choice of instance to explain elif ihm_id == 'ml_instance_choice' : - if value_ml_model is None or pretrained_model_contents is None or enum is None or xtype is None: + if model_application.ml_model is None or model_application.pretrained_model is None or model_application.enum<=0 or model_application.xtype is None : raise PreventUpdate instance = parse_contents_instance(instance_contents, instance_filename) - model_application.update_instance(instance, enum, xtype) - return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + model_application.update_instance(instance) + return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation + # Choice of number of expls elif ihm_id == 'number_explanations' : - if value_ml_model is None or pretrained_model_contents is None or instance_contents is None or xtype is None: + if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.xtype is None: raise PreventUpdate - instance = parse_contents_instance(instance_contents, instance_filename) - model_application.update_instance(instance, enum, xtype) - return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + model_application.update_enum(enum) + return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation + # Choice of AxP or CxP elif ihm_id == 'explanation_type' : - if value_ml_model is None or pretrained_model_contents is None or instance_contents is None or enum is None : + if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.enum<=0 : raise PreventUpdate - instance = parse_contents_instance(instance_contents, instance_filename) - model_application.update_instance(instance, enum, xtype) - return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + model_application.update_xtype(xtype) + return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation + # Choice of solver elif ihm_id == 'solver_sat' : - if value_ml_model is None or pretrained_model_contents is None or instance_contents is None or enum is None or xtype is None: + if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.enum<=0 or len(model_application.xtype)==0: raise PreventUpdate - instance = parse_contents_instance(instance_contents, instance_filename) - model_application.update_instance(instance, enum, xtype, solver=solver) - return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + model_application.update_solver(solver) + return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation + # Choice of AxP to draw elif ihm_id == 'expl_choice' : - if instance_contents is None : + if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.enum<=0 or len(model_application.xtype)==0: raise PreventUpdate model_application.update_expl(expl_choice) - return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation @app.callback( @@ -116,3 +137,16 @@ def register_callbacks(page_home, page_course, page_application, app): for i in range (len(model_application.list_expls)): options[str(model_application.list_expls[i])] = model_application.list_expls[i] return False, False, False, options + + @app.callback( + Output('choice_info_div', 'hidden'), + Input('add_info_model_choice', 'on'), + prevent_initial_call=True + ) + def add_model_info(add_info_model_choice): + model_application = page_application.model + model_application.update_info_needed(add_info_model_choice) + if add_info_model_choice: + return False + else : + return True diff --git a/callbacks_detached.py b/callbacks_detached.py new file mode 100644 index 0000000000000000000000000000000000000000..a8db16ff713479399892471d56224bd0cd9456b2 --- /dev/null +++ b/callbacks_detached.py @@ -0,0 +1,172 @@ +import dash +import pandas as pd +from dash import Input, Output, State +from dash.dependencies import Input, Output, State +from dash.exceptions import PreventUpdate + +from utils import parse_contents_graph, parse_contents_instance, parse_contents_data + + +def register_callbacks(page_home, page_course, page_application, app): + page_list = ['home', 'course', 'application'] + + @app.callback( + Output('page-content', 'children'), + Input('url', 'pathname')) + def display_page(pathname): + if pathname == '/': + return page_home + if pathname == '/application': + return page_application.view.layout + if pathname == '/course': + return page_course + + @app.callback(Output('home-link', 'active'), + Output('course-link', 'active'), + Output('application-link', 'active'), + Input('url', 'pathname')) + def navbar_state(pathname): + active_link = ([pathname == f'/{i}' for i in page_list]) + return active_link[0], active_link[1], active_link[2] + + @app.callback( + Output('graph', 'children'), + Input('ml_model_choice', 'value'), + prevent_initial_call=True + ) + def update_ml_type(value_ml_model): + model_application = page_application.model + model_application.update_ml_model(value_ml_model) + return None + + @app.callback( + Output('pretrained_model_filename', 'children'), + Output('graph', 'children'), + Input('ml_pretrained_model_choice', 'contents'), + State('ml_pretrained_model_choice', 'filename'), + prevent_initial_call=True + ) + def update_ml_pretrained_model(pretrained_model_contents, pretrained_model_filename): + model_application = page_application.model + if model_application.ml_model is None : + raise PreventUpdate + graph = parse_contents_graph(pretrained_model_contents, pretrained_model_filename) + model_application.update_pretrained_model(graph) + if not model_application.add_info : + model_application.update_pretrained_model_layout() + return pretrained_model_filename, model_application.component.network + else : + return pretrained_model_filename, None + + @app.callback( + Output('info_filename', 'children'), + Output('graph', 'children'), + Input('model_info_choice', 'contents'), + State('model_info_choice', 'filename'), + prevent_initial_call=True + ) + def update_info_model(model_info, model_info_filename): + model_application = page_application.model + if model_application.ml_model is None : + raise PreventUpdate + model_info = parse_contents_data(model_info, model_info_filename) + model_application.update_pretrained_model_layout_with_info(model_info, model_info_filename) + return model_info_filename, model_application.component.network + + @app.callback( + Output('instance_filename', 'children'), + Output('graph', 'children'), + Output('explanation', 'children'), + Input('ml_instance_choice', 'contents'), + State('ml_instance_choice', 'filename'), + prevent_initial_call=True + ) + def update_instance(instance_contents, instance_filename): + model_application = page_application.model + if model_application.ml_model is None or model_application.pretrained_model is None or model_application.enum<=0 or model_application.xtype is None : + raise PreventUpdate + instance = parse_contents_instance(instance_contents, instance_filename) + model_application.update_instance(instance) + return instance_filename, model_application.component.network, model_application.component.explanation + + @app.callback( + Output('explanation', 'children'), + Input('number_explanations', 'value'), + prevent_initial_call=True + ) + def update_enum(enum): + model_application = page_application.model + if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.xtype is None: + raise PreventUpdate + model_application.update_enum(enum) + return model_application.component.explanation + + @app.callback( + Output('explanation', 'children'), + Input('explanation_type', 'value'), + prevent_initial_call=True + ) + def update_xtype(xtype): + model_application = page_application.model + if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.enum<=0 : + raise PreventUpdate + model_application.update_xtype(xtype) + return model_application.component.explanation + + @app.callback( + Output('explanation', 'children'), + Input('solver_sat', 'value'), + prevent_initial_call=True +) + def update_solver(solver): + model_application = page_application.model + if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.enum<=0 or len(model_application.xtype)==0: + raise PreventUpdate + model_application.update_solver(solver) + return model_application.component.explanation + + @app.callback( + Output('graph', 'children'), + Input('expl_choice', 'value'), + prevent_initial_call=True + ) + def update_expl_choice( expl_choice): + model_application = page_application.model + if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.enum<=0 or len(model_application.xtype)==0: + raise PreventUpdate + model_application.update_expl(expl_choice) + return model_application.component.network + + @app.callback( + Output('explanation', 'hidden'), + Output('navigate_label', 'hidden'), + Output('navigate_dropdown', 'hidden'), + Output('expl_choice', 'options'), + Input('explanation', 'children'), + Input('explanation_type', 'value'), + prevent_initial_call=True + ) + def layout_buttons_navigate_expls(explanation, explanation_type): + if explanation is None or len(explanation_type)==0: + return True, True, True, {} + elif "AXp" not in explanation_type and "CXp" in explanation_type: + return False, True, True, {} + else : + options = {} + model_application = page_application.model + for i in range (len(model_application.list_expls)): + options[str(model_application.list_expls[i])] = model_application.list_expls[i] + return False, False, False, options + + @app.callback( + Output('choice_info_div', 'hidden'), + Input('add_info_model_choice', 'on'), + prevent_initial_call=True + ) + def add_model_info(add_info_model_choice): + model_application = page_application.model + model_application.update_info_needed(add_info_model_choice) + if add_info_model_choice: + return False + else : + return True diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py index 662884f4a755556a110e97b99456adaf860b100e..8a56c32349fced19e8c03889bb26fdc176b79973 100644 --- a/pages/application/DecisionTree/DecisionTreeComponent.py +++ b/pages/application/DecisionTree/DecisionTreeComponent.py @@ -1,40 +1,139 @@ from os import path +import base64 import dash_bootstrap_components as dbc import dash_interactive_graphviz import numpy as np from dash import dcc, html +from pages.application.DecisionTree.utils.upload_tree import UploadedDecisionTree +from pages.application.DecisionTree.utils.data import Data from pages.application.DecisionTree.utils.dtree import DecisionTree + from pages.application.DecisionTree.utils.dtviz import (visualize, visualize_expl, visualize_instance) - class DecisionTreeComponent(): - def __init__(self, tree, typ_data): + def __init__(self, tree, type_tree='SKL', info=None, type_info=''): + + if info is not None and '.csv' in type_info: + self.categorical = True + data = Data(info) + fvmap = data.mapping_features() + feature_names = data.names[:-1] + self.uploaded_dt = UploadedDecisionTree(tree, type_tree, maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_) + self.dt_format, self.map, features_names_mapping = self.uploaded_dt.dump(fvmap, feat_names=feature_names) - self.dt = DecisionTree(from_pickle = tree) + elif info is not None and '.txt' in type_info : + self.categorical = True + fvmap = {} + feature_names = [] + for i,line in enumerate(info.split('\n')): + fid, TYPE = line.split(',')[:2] + dom = line.split(',')[2:] + assert (fid not in feature_names) + feature_names.append(fid) + assert (TYPE in ['Binary', 'Categorical']) + fvmap[f'f{i}'] = dict() + dom = sorted(dom) + for j,v in enumerate(dom): + fvmap[f'f{i}'][j] = (fid, True, v) + self.uploaded_dt = UploadedDecisionTree(tree, type_tree, maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_) + self.dt_format, self.map, features_names_mapping = self.uploaded_dt.dump(fvmap, feat_names=feature_names) + else : + self.categorical = False + try: + feature_names = tree.feature_names_in_ + except: + feature_names = [f'f{i}' for i in range(tree.n_features_in_)] + self.uploaded_dt = UploadedDecisionTree(tree, type_tree, maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_) + self.dt_format, self.map, features_names_mapping = self.uploaded_dt.convert_dt(feat_names=feature_names) + + self.mapping_instance = self.create_fvmap_inverse(features_names_mapping) + self.dt = DecisionTree(from_dt=self.dt_format, mapfile = self.map, feature_names = feature_names) dot_source = visualize(self.dt) - self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(dot_source=dot_source, style = {"width": "60%", + self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz(dot_source=dot_source, style = {"width": "60%", "height": "90%", - "background-color": "transparent"}))] + "background-color": "transparent"})]) self.explanation = [] + + def create_fvmap_inverse(self, instance): + def create_fvmap_inverse_with_info(features_names_mapping) : + mapping_instance = {} + for feat in features_names_mapping : + feat_dic = {} + feature_description = feat.split(',') + name_feat, id_feat = feature_description[1].split(':') + + for mapping in feature_description[2:]: + real_value, mapped_value = mapping.split(':') + feat_dic[np.float32(real_value)] = int(mapped_value) + mapping_instance[name_feat] = feat_dic + + return mapping_instance + + def create_fvmap_inverse_threashold(features_names_mapping) : + mapping_instance = {} + for feat in features_names_mapping : + feature_description = feat.split(',') + name_feat, id_feat = feature_description[1].split(':') + mapping_instance[name_feat] = float(feature_description[2].split(':')[0]) + + return mapping_instance + + if self.categorical : + return create_fvmap_inverse_with_info(instance) + else : + return create_fvmap_inverse_threashold(instance) + + + def translate_instance(self, instance): + def translate_instance_categorical(instance): + instance_translated = [] + for feat, real_value in instance : + instance_translated.append((feat, self.mapping_instance[feat][real_value])) + return instance_translated + + def translate_instance_threasholds(instance): + instance_translated = [] + for feat, real_value in instance : + try: + if real_value <= self.mapping_instance[feat]: + instance_translated.append((feat, 0)) + else : + instance_translated.append((feat, 1)) + except: + instance_translated.append((feat, real_value)) + return instance_translated + + if self.categorical : + return translate_instance_categorical(instance) + else : + return translate_instance_threasholds(instance) + + def update_with_explicability(self, instance, enum, xtype, solver) : - dot_source = visualize_instance(self.dt, instance) - self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz( + + instance_translated = self.translate_instance(instance) + self.explanation = [] + list_explanations_path=[] + explanation = self.dt.explain(instance_translated, enum=enum, xtype = xtype, solver=solver) + + dot_source = visualize_instance(self.dt, instance_translated) + self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( dot_source=dot_source, style = {"width": "50%", "height": "80%", "background-color": "transparent"} - ))] + )]) - self.explanation = [] - list_explanations_path=[] - explanation = self.dt.explain(instance, enum=enum, xtype = xtype, solver=solver) #Creating a clean and nice text component + #instance plotting + self.explanation.append(html.H4("Instance : \n")) + self.explanation.append(html.P(str([str(instance[i]) for i in range (len(instance))]))) for k in explanation.keys() : if k != "List of path explanation(s)": if k in ["List of abductive explanation(s)","List of contrastive explanation(s)"] : @@ -51,9 +150,10 @@ class DecisionTreeComponent(): return list_explanations_path def draw_explanation(self, instance, expl) : + instance = self.translate_instance(instance) dot_source = visualize_expl(self.dt, instance, expl) - self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz( + self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( dot_source=dot_source, style = {"width": "50%", "height": "80%", - "background-color": "transparent"}))] + "background-color": "transparent"})]) diff --git a/pages/application/DecisionTree/utils/data.py b/pages/application/DecisionTree/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..91aded54aaa39c6239d3a0696beafb217dc9a8a4 --- /dev/null +++ b/pages/application/DecisionTree/utils/data.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +#-*- coding:utf-8 -*- +## +## data.py +## +## Created on: Sep 20, 2017 +## Author: Alexey Ignatiev, Nina Narodytska +## E-mail: aignatiev@ciencias.ulisboa.pt, narodytska@vmware.com +## + +# +#============================================================================== +from __future__ import print_function +import collections +import itertools +import pickle +import six +import gzip +from six.moves import range +import numpy as np +import pandas as pd + +# +#============================================================================== +class Data(object): + """ + Class for representing data (transactions). + """ + + def __init__(self, data, separator=','): + """ + Constructor and parser. + """ + self.names = None + self.nm2id = None + self.feats = None + self.targets = None + self.samples = None + + self.parse(data, separator) + + def parse(self, data, separator): + """ + Parse input file. + """ + + # reading data set from file + lines = data.split('\n') + + # reading preamble + self.names = [name.replace('"','').strip() for name in lines[0].strip().split(separator)] + self.feats = [set([]) for n in self.names[:-1]] + self.targets = set([]) + + lines = lines[1:] + + # filling name to id mapping + self.nm2id = {name: i for i, name in enumerate(self.names)} + + self.nonbin2bin = {} + for name in self.nm2id: + spl = name.rsplit(':',1) + if (spl[0] not in self.nonbin2bin): + self.nonbin2bin[spl[0]] = [name] + else: + self.nonbin2bin[spl[0]].append(name) + + # reading training samples + self.samples = [] + + for line, w in six.iteritems(collections.Counter(lines)): + inst = [v.strip() for v in line.strip().split(separator)] + self.samples.append(inst) + for i, v in enumerate(inst[:-1]): + if v: + self.feats[i].add(str(v)) + assert(inst[-1]) + self.targets.add(str(inst[-1])) + + self.nof_feats = len(self.names[:-1]) + + def mapping_features(self): + """ + feature-value mapping + """ + fvmap = {} + + for i in range(self.nof_feats): + fvmap[f'f{i}'] = dict() + for j, v in enumerate(sorted(self.feats[i])): + fvmap[f'f{i}'][j] = (self.names[i], True, v) + + if len(self.feats[i]) > 2: + m = len(self.feats[i]) + for j, v in enumerate(sorted(self.feats[i])): + fvmap[f'f{i}'][j+m] = (self.names[i], False, v) + + return fvmap diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py index 6bd118dfd8cb6ca577338d8a8189cb00cea33a68..79088c2cce1cbb0c4c6e65183195a4a110f357f6 100644 --- a/pages/application/DecisionTree/utils/dtree.py +++ b/pages/application/DecisionTree/utils/dtree.py @@ -11,47 +11,32 @@ # #============================================================================== from __future__ import print_function - import collections from functools import reduce - -import sklearn from pysat.card import * from pysat.examples.hitman import Hitman from pysat.formula import CNF, IDPool from pysat.solvers import Solver -from torch import threshold try: # for Python2 from cStringIO import StringIO except ImportError: # for Python3 from io import StringIO - -import numpy as np -from dash import dcc, html from sklearn.tree import _tree +import numpy as np - -# -#============================================================================== class Node(): """ Node class. """ - def __init__(self, feat='', vals=None, threshold=None, children_left= None, children_right=None): + def __init__(self, feat='', vals=[]): """ Constructor. """ self.feat = feat - if threshold is not None : - self.threshold = threshold - self.children_left = 0 - self.children_right = 0 - else : - self.vals = {} - + self.vals = vals # #============================================================================== @@ -60,13 +45,12 @@ class DecisionTree(): Simple decision tree class. """ - def __init__(self, from_pickle=None, verbose=0): + def __init__(self, from_dt=None, mapfile=None, feature_names=None, verbose=0): """ Constructor. """ self.verbose = verbose - self.typ="" self.nof_nodes = 0 self.nof_terms = 0 @@ -76,112 +60,266 @@ class DecisionTree(): self.paths = {} self.feats = [] self.feids = {} + self.fdoms = {} + self.fvmap = {} + self.feature_names = {f'f{i}' : feature_names[i] for i, f in enumerate(feature_names)} - if from_pickle: - self.typ="pkl" - self.tree_ = '' - self.from_pickle_file(from_pickle) - - #problem de feature names et problem de vals dans node - def from_pickle_file(self, tree): - #help(_tree.Tree) - self.tree_ = tree.tree_ - #print(sklearn.tree.export_text(tree)) - try: - feature_names = tree.feature_names_in_ - except: - print("You did not dump the model with the features names") - feature_names = [str(i) for i in range(tree.n_features_in_)] - - class_names = tree.classes_ - self.nodes = collections.defaultdict(lambda: Node(feat='', threshold=int(0), children_left=int(0), children_right=int(0))) - self.terms={} - self.nof_nodes = self.tree_.node_count - self.root_node = 0 - self.feats = feature_names - - feature_name = [ - feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" - for i in self.tree_.feature] - - def recurse(node): - if self.tree_.feature[node] != _tree.TREE_UNDEFINED: - name = feature_name[node] - val = self.tree_.threshold[node] - - #faire une boucle for des vals ? - self.nodes[int(node)].feat = name - self.nodes[int(node)].threshold = np.round(val, 4) - self.nodes[int(node)].children_left = int(self.tree_.children_left[node]) - self.nodes[int(node)].children_right = int(self.tree_.children_right[node]) - - recurse(self.tree_.children_left[node]) - recurse(self.tree_.children_right[node]) + # OHE mapping + OHEMap = collections.namedtuple('OHEMap', ['dir', 'opp']) + self.ohmap = OHEMap(dir={}, opp={}) - else: - self.terms[node] = class_names[np.argmax(self.tree_.value[node])] - - recurse(self.root_node) + if from_dt: + self.from_dt(from_dt) + + if mapfile: + self.parse_mapping(mapfile) + else: # no mapping is given + for f in self.feats: + for v in self.fdoms[f]: + self.fvmap[tuple([f, v])] = '{0}={1}'.format(f, v) + + def from_dt(self, data): + """ + Get the tree from a file pointer. + """ + + contents = StringIO(data) + + lines = contents.readlines() + + # filtering out comment lines (those that start with '#') + lines = list(filter(lambda l: not l.startswith('#'), lines)) + # number of nodes + self.nof_nodes = int(lines[0].strip()) + + # root node + self.root_node = int(lines[1].strip()) + + # number of terminal nodes (classes) + self.nof_terms = len(lines[3][2:].strip().split()) + + # the ordered list of terminal nodes + self.terms = {} + for i in range(self.nof_terms): + nd, _, t = lines[i + 4].strip().split() + self.terms[int(nd)] = t #int(t) + + # finally, reading the nodes + self.nodes = collections.defaultdict(lambda: Node(feat='', vals={})) + self.feats = set([]) + self.fdoms = collections.defaultdict(lambda: set([])) + for line in lines[(4 + self.nof_terms):]: + # reading the tuple + nid, fid, fval, child = line.strip().split() + + # inserting it in the nodes list + self.nodes[int(nid)].feat = fid + self.nodes[int(nid)].vals[int(fval)] = int(child) + + # updating the list of features + self.feats.add(fid) + + # updaing feature domains + self.fdoms[fid].add(int(fval)) + + # adding complex node connections into consideration + for n1 in self.nodes: + conns = collections.defaultdict(lambda: set([])) + for v, n2 in self.nodes[n1].vals.items(): + conns[n2].add(v) + self.nodes[n1].vals = {frozenset(v): n2 for n2, v in conns.items()} + + # simplifying the features and their domains self.feats = sorted(self.feats) - self.feids = {f: i for i, f in enumerate(self.feats)} - self.nof_terms = len(self.terms) - self.nof_nodes -= len(self.terms) + #self.feids = {f: i for i, f in enumerate(self.feats)} + self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms} + + # here we assume all features are present in the tree + # if not, this value will be rewritten by self.parse_mapping() self.nof_feats = len(self.feats) self.paths = collections.defaultdict(lambda: []) self.extract_paths(root=self.root_node, prefix=[]) + def parse_mapping(self, mapfile): + """ + Parse feature-value mapping from a file. + """ + self.fvmap = {} + + lines = mapfile.split('\n') + + if lines[0].startswith('OHE'): + for i in range(int(lines[1])): + feats = lines[i + 2].strip().split(',') + orig, ohe = feats[0], tuple(feats[1:]) + self.ohmap.dir[orig] = tuple(ohe) + for f in ohe: + self.ohmap.opp[f] = orig + + lines = lines[(int(lines[1]) + 2):] + + elif lines[0].startswith('Categorical'): + # skipping the first comment line if necessary + lines = lines[1:] + + elif lines[0].startswith('Ordinal'): + # skipping the first comment line if necessary + lines = lines[1:] + + # number of features + self.nof_feats = int(lines[0].strip()) + self.feids = {} + + for line in lines[1:]: + feat, val, real = line.split() + self.fvmap[tuple([feat, int(val)])] = '{0}{1}'.format(self.feature_names[feat], real) + #if feat not in self.feids: + # self.feids[feat] = len(self.feids) + + #assert len(self.feids) == self.nof_feats + + def convert_to_multiedges(self): + """ + Convert ITI trees with '!=' edges to multi-edges. + """ + # new feature domains + fdoms = collections.defaultdict(lambda: []) + + # tentative mapping relating negative and positive values + nemap = collections.defaultdict(lambda: collections.defaultdict(lambda: [None, None])) + + for fv, tval in self.fvmap.items(): + if '!=' in tval: + nemap[fv[0]][tval.split('=')[1]][0] = fv[1] + else: + fdoms[fv[0]].append(fv[1]) + nemap[fv[0]][tval.split('=')[1]][1] = fv[1] + + # a mapping from negative values to sets + fnmap = collections.defaultdict(lambda: {}) + for f in nemap: + for t, vals in nemap[f].items(): + if vals[0] != None: + fnmap[(f, frozenset({vals[0]}))] = frozenset(set(fdoms[f]).difference({vals[1]})) + + # updating node connections + for n in self.nodes: + vals = {} + for v in self.nodes[n].vals.keys(): + fn = (self.nodes[n].feat, v) + if fn in fnmap: + vals[fnmap[fn]] = self.nodes[n].vals[v] + else: + vals[v] = self.nodes[n].vals[v] + self.nodes[n].vals = vals + + # updating the domains + self.fdoms = fdoms + + # extracting the paths again + self.paths = collections.defaultdict(lambda: []) + self.extract_paths(root=self.root_node, prefix=[]) + def extract_paths(self, root, prefix): """ Traverse the tree and extract explicit paths. """ - - if root in self.terms.keys(): + if root in self.terms: # store the path term = self.terms[root] self.paths[term].append(prefix) else: # select next node - feat, threshold, children_left, children_right = self.nodes[root].feat, self.nodes[root].threshold, self.nodes[root].children_left, self.nodes[root].children_right - self.extract_paths(children_left, prefix + [tuple([feat, "<=" + str(threshold)])]) - self.extract_paths(children_right, prefix + [tuple([feat, ">"+ str(threshold)])]) - - def execute(self, inst): - inst = np.array([inst]) - path = self.tree_.decision_path(inst) - term_id_node = self.tree_.apply(inst) - term_id_node = term_id_node[0] - path = path.indices[path.indptr[0] : path.indptr[0 + 1]] - return path, term_id_node + feat, vals = self.nodes[root].feat, self.nodes[root].vals + for val in vals: + self.extract_paths(vals[val], prefix + [tuple([feat, val])]) + + def execute(self, inst, pathlits=False): + """ + Run the tree and obtain the prediction given an input instance. + """ + root = self.root_node + depth = 0 + path = [] + + # this array is needed if we focus on the path's literals only + visited = [False for f in inst] + + while not root in self.terms: + path.append(root) + feat, vals = self.nodes[root].feat, self.nodes[root].vals + visited[self.feids[feat]] = True + tval = inst[self.feids[feat]][1] + ############### + # assert(len(vals) == 2) + next_node = root + neq = None + for vs, dest in vals.items(): + if tval in vs: + next_node = dest + break + else: + for v in vs: + if '!=' in self.fvmap[(feat, v)]: + neq = dest + break + else: + next_node = neq + # if tval not in vals: + # # go to the False branch (!=) + # for i in vals: + # if "!=" in self.fvmap[(feat,i)]: + # next_node = vals[i] + # break + # else: + # next_node = vals[tval] + + assert (next_node != root) + ############### + root = next_node + depth += 1 + + if pathlits: + # filtering out non-visited literals + for i, v in enumerate(visited): + if not v: + inst[i] = None + + return path, self.terms[root], depth def prepare_sets(self, inst, term): """ Hitting set based encoding of the problem. (currently not incremental -- should be fixed later) """ - sets = [] for t, paths in self.paths.items(): # ignoring the right class - if term in self.terms.keys() and self.terms[term] == t: + if t == term: continue # computing the sets to hit for path in paths: to_hit = [] for item in path: + fv = inst[self.feids[item[0]]] # if the instance disagrees with the path on this item - if ("<=" in item[1] and (inst[item[0]] > np.float32(item[1][2:]))) or (">" in item[1] and (inst[item[0]] <= np.float32(item[1][1:]))) : - if "<=" in item[1] : - fv = tuple([item[0], str(inst[item[0]]), ">" , str(np.float32(item[1][2:]))]) - else : - fv = tuple([item[0], str(inst[item[0]]) , "<=" , str(np.float32(item[1][1:]))]) - to_hit.append(fv) + if fv and not fv[1] in item[1]: + if fv[0] in self.ohmap.opp: + to_hit.append(tuple([self.ohmap.opp[fv[0]], None])) + else: + to_hit.append(fv) - if len(to_hit)>0 : - to_hit = sorted(set(to_hit)) - sets.append(tuple(to_hit)) + to_hit = sorted(set(to_hit)) + sets.append(tuple(to_hit)) + + if self.verbose: + if self.verbose > 1: + print('c trav. path: {0}'.format(path)) + + print('c set to hit: {0}'.format(to_hit)) # returning the set of sets with no duplicates return list(dict.fromkeys(sets)) @@ -190,38 +328,26 @@ class DecisionTree(): """ Compute a given number of explanations. """ - - inst_values = [np.float32(i[1]) for i in inst] - inst_dic = {} - for i in range(len(inst)): - inst_dic[inst[i][0]] = np.float32(inst[i][1]) - path, term = self.execute(inst_values) - #contaiins all the elements for explanation explanation_dic = {} - #instance plotting - explanation_dic["Instance : "] = str(inst_dic) - - #decision path - decision_path_str = "IF : " - for node_id in path: - # continue to the next node if it is a leaf node - if term == node_id: - continue - decision_path_str +="(inst[{feature}] = {value}) {inequality} {threshold}) AND ".format( - feature=self.nodes[node_id].feat, - value=inst_dic[self.nodes[node_id].feat], - inequality="<=" if inst_dic[self.nodes[node_id].feat] <= self.nodes[node_id].threshold else ">" , - threshold=self.nodes[node_id].threshold) + self.feids = {f'f{i}': i for i, f in enumerate(inst)} + inst = [(f'f{i}', int(inst[i][1])) for i,f in enumerate(inst)] + path, term, depth = self.execute(inst, pathlits) - decision_path_str += "THEN " + str(self.terms[term]) + #decision path + decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[inst[self.feids[self.nodes[n].feat]]] for n in path]), term) explanation_dic["Decision path of instance : "] = decision_path_str - explanation_dic["Decision path length : "] = 'Path length is :'+ str(len(path)) + explanation_dic["Decision path length : "] = 'Path length is :'+ str(depth) - # computing the sets to hit - to_hit = self.prepare_sets(inst_dic, term) + if self.ohmap.dir: + f2v = {fv[0]: fv[1] for fv in inst} + # updating fvmap for printing ohe features + for fo, fis in self.ohmap.dir.items(): + self.fvmap[tuple([fo, None])] = '(' + ' AND '.join([self.fvmap[tuple([fi, f2v[fi]])] for fi in fis]) + ')' + # computing the sets to hit + to_hit = self.prepare_sets(inst, term) for type in xtype : if type == "AXp": explanation_dic.update(self.enumerate_abductive(to_hit, enum, solver, htype, term)) @@ -240,12 +366,9 @@ class DecisionTree(): with Hitman(bootstrap_with=to_hit, solver='m22', htype=htype) as hitman: expls = [] for i, expl in enumerate(hitman.enumerate(), 1): - list_expls.append([ p[0] + p[2] + p[3] for p in expl]) - list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join(["(inst[{feature}] = {value}) {inequality} {threshold})".format(feature=p[0], - value=p[1], - inequality=p[2], - threshold=p[3]) - for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))) + list_expls.append([self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]) + list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]), term)) + expls.append(expl) if i == enum: break @@ -262,7 +385,6 @@ class DecisionTree(): """ Enumerate contrastive explanations. """ - def process_set(done, target): for s in done: if s <= target: @@ -277,10 +399,8 @@ class DecisionTree(): list_expls_str = [] explanation = {} for expl in expls: - list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(["inst[{feature}] {inequality} {threshold})".format(feature=p[0], - inequality="<=" if p[2]==">" else ">", - threshold=p[3]) - for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))) + list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term)) + explanation["List of contrastive explanation(s)"] = list_expls_str explanation["Number of contrastive explanation(s) : "]=str(len(expls)) explanation["Minimal contrastive explanation : "]= str( min([len(e) for e in expls])) diff --git a/pages/application/DecisionTree/utils/dtviz.py b/pages/application/DecisionTree/utils/dtviz.py index aacf7b646dd1384157116c138a125a069670c772..21ea63920338159b71e45c44a1cd2b772084bc9f 100755 --- a/pages/application/DecisionTree/utils/dtviz.py +++ b/pages/application/DecisionTree/utils/dtviz.py @@ -8,12 +8,16 @@ ## E-mail: alexey.ignatiev@monash.edu ## -import numpy as np +# +#============================================================================== +import getopt import pygraphviz + # #============================================================================== def create_legend(g): legend = g.subgraphs()[-1] + legend.graph_attr.update(size="2,2") legend.add_node("a", style = "invis") legend.add_node("b", style = "invis") legend.add_node("c", style = "invis") @@ -31,55 +35,47 @@ def create_legend(g): edge.attr["style"] = "dashed" +# +#============================================================================== def visualize(dt): """ Visualize a DT with graphviz. """ - g = pygraphviz.AGraph(name='root', rankdir="TB") - g.is_directed() - g.is_strict() - - #g = pygraphviz.AGraph(name = "main", directed=True, strict=True) + g = pygraphviz.AGraph(directed=True, strict=True) g.edge_attr['dir'] = 'forward' + g.graph_attr['rankdir'] = 'TB' # non-terminal nodes for n in dt.nodes: - g.add_node(n, label=str(dt.nodes[n].feat)) + g.add_node(n, label=dt.feature_names[dt.nodes[n].feat]) node = g.get_node(n) node.attr['shape'] = 'circle' node.attr['fontsize'] = 13 # terminal nodes for n in dt.terms: - g.add_node(n, label=str(dt.terms[n])) + g.add_node(n, label=dt.terms[n]) node = g.get_node(n) node.attr['shape'] = 'square' node.attr['fontsize'] = 13 + # transitions for n1 in dt.nodes: - threshold = dt.nodes[n1].threshold - - children_left = dt.nodes[n1].children_left - g.add_edge(n1, children_left) - edge = g.get_edge(n1, children_left) - edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - - children_right = dt.nodes[n1].children_right - g.add_edge(n1, children_right) - edge = g.get_edge(n1, children_right) - edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - - g.add_subgraph(name='legend') - create_legend(g) + for v in dt.nodes[n1].vals: + n2 = dt.nodes[n1].vals[v] + g.add_edge(n1, n2) + edge = g.get_edge(n1, n2) + if len(v) == 1: + edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] + else: + edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 # saving file g.layout(prog='dot') - return(g.string()) + return(g.to_string()) # #============================================================================== @@ -87,53 +83,50 @@ def visualize_instance(dt, instance): """ Visualize a DT with graphviz and plot the running instance. """ + #path that follows the instance - colored in blue + path, term, depth = dt.execute(instance) + edges_instance = [] + for i in range (len(path)-1) : + edges_instance.append((path[i], path[i+1])) + edges_instance.append((path[-1],"term:"+term)) + g = pygraphviz.AGraph(directed=True, strict=True) g.edge_attr['dir'] = 'forward' g.graph_attr['rankdir'] = 'TB' # non-terminal nodes for n in dt.nodes: - g.add_node(n, label=str(dt.nodes[n].feat)) + g.add_node(n, label=dt.feature_names[dt.nodes[n].feat]) node = g.get_node(n) node.attr['shape'] = 'circle' node.attr['fontsize'] = 13 # terminal nodes for n in dt.terms: - g.add_node(n, label=str(dt.terms[n])) + g.add_node(n, label=dt.terms[n]) node = g.get_node(n) node.attr['shape'] = 'square' node.attr['fontsize'] = 13 - #path that follows the instance - colored in blue - instance = [np.float32(i[1]) for i in instance] - path, term_id_node = dt.execute(instance) - edges_instance = [] - for i in range (len(path)-1) : - edges_instance.append((path[i], path[i+1])) - + # transitions for n1 in dt.nodes: - threshold = dt.nodes[n1].threshold - - children_left = dt.nodes[n1].children_left - g.add_edge(n1, children_left) - edge = g.get_edge(n1, children_left) - edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - #instance path in blue - if ((n1,children_left) in edges_instance): - edge.attr['style'] = 'dashed' - - children_right = dt.nodes[n1].children_right - g.add_edge(n1, children_right) - edge = g.get_edge(n1, children_right) - edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - #instance path in blue - if ((n1,children_right) in edges_instance): - edge.attr['style'] = 'dashed' + for v in dt.nodes[n1].vals: + n2 = dt.nodes[n1].vals[v] + n2_type = g.get_node(n2).attr['shape'] + + g.add_edge(n1, n2) + edge = g.get_edge(n1, n2) + if len(v) == 1: + edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] + else: + edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) + + #instance path in dashed + if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): + edge.attr['style'] = 'dashed' + + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 g.add_subgraph(name='legend') create_legend(g) @@ -141,66 +134,63 @@ def visualize_instance(dt, instance): # saving file g.layout(prog='dot') return(g.to_string()) -# + #============================================================================== def visualize_expl(dt, instance, expl): """ Visualize a DT with graphviz and plot the running instance. """ + #path that follows the instance - colored in blue + path, term, depth = dt.execute(instance) + edges_instance = [] + for i in range (len(path)-1) : + edges_instance.append((path[i], path[i+1])) + edges_instance.append((path[-1],"term:"+term)) + g = pygraphviz.AGraph(directed=True, strict=True) g.edge_attr['dir'] = 'forward' + g.graph_attr['rankdir'] = 'TB' # non-terminal nodes for n in dt.nodes: - g.add_node(n, label=str(dt.nodes[n].feat)) + g.add_node(n, label=dt.feature_names[dt.nodes[n].feat]) node = g.get_node(n) node.attr['shape'] = 'circle' node.attr['fontsize'] = 13 # terminal nodes for n in dt.terms: - g.add_node(n, label=str(dt.terms[n])) + g.add_node(n, label=dt.terms[n]) node = g.get_node(n) node.attr['shape'] = 'square' node.attr['fontsize'] = 13 - #path that follows the instance - colored in blue - instance = [np.float32(i[1]) for i in instance] - path, term_id_node = dt.execute(instance) - edges_instance = [] - for i in range (len(path)-1) : - edges_instance.append((path[i], path[i+1])) - + # transitions for n1 in dt.nodes: - threshold = dt.nodes[n1].threshold - - children_left = dt.nodes[n1].children_left - g.add_edge(n1, children_left) - edge = g.get_edge(n1, children_left) - edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - #instance path in blue - if ((n1,children_left) in edges_instance): - edge.attr['style'] = 'dashed' - if edge.attr['label'] in expl : - edge.attr['color'] = 'blue' - - children_right = dt.nodes[n1].children_right - g.add_edge(n1, children_right) - edge = g.get_edge(n1, children_right) - edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - #instance path in blue - if ((n1,children_right) in edges_instance): - edge.attr['style'] = 'dashed' - if edge.attr['label'] in expl : - edge.attr['color'] = 'blue' + for v in dt.nodes[n1].vals: + n2 = dt.nodes[n1].vals[v] + n2_type = g.get_node(n2).attr['shape'] + g.add_edge(n1, n2) + edge = g.get_edge(n1, n2) + if len(v) == 1: + edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] + else: + edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) + + #instance path in dashed + if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): + edge.attr['style'] = 'dashed' + for label in edge.attr['label'].split('\n'): + if label in expl: + edge.attr['color'] = 'blue' + + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 g.add_subgraph(name='legend') create_legend(g) + # saving file g.layout(prog='dot') return(g.to_string()) diff --git a/pages/application/DecisionTree/utils/upload_tree.py b/pages/application/DecisionTree/utils/upload_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..98bf190cfc588cae331fff2caeeb85df06b87645 --- /dev/null +++ b/pages/application/DecisionTree/utils/upload_tree.py @@ -0,0 +1,481 @@ +#!/usr/bin/env python +#-*- coding:utf-8 -*- +## +## tree.py (reuses parts of the code of SHAP) +## +## Created on: Dec 7, 2018 +## Author: Nina Narodytska +## E-mail: narodytska@vmware.com +## + +# +#============================================================================== +from anytree import Node, RenderTree,AsciiStyle +import json +import numpy as np +import math +import six + + +# +#============================================================================== +class xgnode(Node): + def __init__(self, id, parent = None): + Node.__init__(self, id, parent) + self.id = id # node value + self.name = None + self.left_node_id = -1 # left child + self.right_node_id = -1 # right child + + self.feature = -1 + self.threshold = None + self.values = -1 + #iai + self.split = None + + def __str__(self): + pref = ' ' * self.depth + if (len(self.children) == 0): + return (pref+ f"leaf:{self.id} {self.values}") + else: + if(self.name is None): + if (self.threshold is None): + return (pref+ f"({self.id}) f{self.feature}") + else: + return (pref+ f"({self.id}) f{self.feature} = {self.threshold}") + else: + if (self.threshold is None): + return (pref+ f"({self.id}) \"{self.name}\"") + else: + return (pref+ f"({self.id}) \"{self.name}\" = {self.threshold}") + +# +#============================================================================== + +def walk_tree(node): + if (len(node.children) == 0): + # leaf + print(node) + else: + print(node) + walk_tree(node.children[0]) + walk_tree(node.children[1]) + + +# +#============================================================================== +def scores_tree(node, sample): + if (len(node.children) == 0): + # leaf + return node.values + else: + feature_branch = node.feature + sample_value = sample[feature_branch] + assert(sample_value is not None) + if(sample_value < node.threshold): + return scores_tree(node.children[0], sample) + else: + return scores_tree(node.children[1], sample) + + + +# +#============================================================================== +def get_json_tree(model, tool, maxdepth=None): + """ + returns the dtree in JSON format + """ + jt = None + if tool == "DL85": + jt = model.tree_ + elif tool == "IAI": + fname = os.path.splitext(os.path.basename(fname))[0] + dir_name = os.path.join("temp", f"{tool}{maxdepth}") + try: + os.stat(dir_name) + except: + os.makedirs(dir_name) + iai_json = os.path.join(dir_name, fname+'.json') + model.write_json(iai_json) + print(f'load JSON tree from {iai_json} ...') + with open(iai_json) as fp: + jt = json.load(fp) + elif tool == "ITI": + print(f'load JSON tree from {model.json_name} ...') + with open(model.json_name) as fp: + jt = json.load(fp) + #else: + # assert False, 'Unhandled model type: {0}'.format(self.tool) + + return jt + +# +#============================================================================== +class UploadedDecisionTree: + """ A decision tree. + This object provides a common interface to many different types of models. + """ + def __init__(self, model, tool, maxdepth, feature_names=None, nb_classes = 0): + self.tool = tool + self.model = model + self.tree = None + self.depth = None + self.n_nodes = None + json_tree = get_json_tree(self.model, self.tool, maxdepth) + self.tree, self.n_nodes, self.depth = self.build_tree(json_tree, feature_names) + + def print_tree(self): + print("DT model:") + walk_tree(self.tree) + + + def dump(self, fvmap, filename=None, maxdepth=None, feat_names=None): + """ + save the dtree and data map in .dt/.map file + """ + + def walk_tree(node, domains, internal, terminal): + """ + extract internal (non-term) & terminal nodes + """ + if (len(node.children) == 0): # leaf node + terminal.append((node.id, node.values)) + else: + assert (node.children[0].id == node.left_node_id) + assert (node.children[1].id == node.right_node_id) + + f = f"f{node.feature}" + + if self.tool == "DL85": + l,r = (1,0) + internal.append((node.id, f, l, node.children[0].id)) + internal.append((node.id, f, r, node.children[1].id)) + + elif self.tool == "ITI": + #l,r = (0,1) + if len(fvmap[f]) > 2: + n = 0 + for v in fvmap[f]: + if (fvmap[f][v][2] == node.threshold) and \ + (fvmap[f][v][1] == True): + l = v + n = n + 1 + if (fvmap[f][v][2] == node.threshold) and \ + (fvmap[f][v][1] == False): + r = v + n = n + 1 + + assert (n == 2) + + elif (fvmap[f][0][2] == node.threshold): + l,r = (0,1) + else: + assert (fvmap[f][1][2] == node.threshold) + l,r = (1,0) + + internal.append((node.id, f, l, node.children[0].id)) + internal.append((node.id, f, r, node.children[1].id)) + + elif self.tool == "IAI": + left, right = [], [] + for p in fvmap[f]: + if fvmap[f][p][1] == True: + assert (fvmap[f][p][2] in node.split) + if node.split[fvmap[f][p][2]]: + left.append(p) + else: + right.append(p) + + internal.extend([(node.id, f, l, node.children[0].id) for l in left]) + internal.extend([(node.id, f, r, node.children[1].id) for r in right]) + + elif self.tool == 'SKL': + left, right = [], [] + for j in domains[f]: + if np.float32(fvmap[f][j][2]) <= np.float32(node.threshold): + left.append(j) + else: + right.append(j) + + internal.extend([(node.id, f, l, node.children[0].id) for l in left]) + internal.extend([(node.id, f, r, node.children[1].id) for r in right]) + + dom0, dom1 = dict(), dict() + dom0.update(domains) + dom1.update(domains) + dom0[f] = left + dom1[f] = right + + else: + assert False, 'Unhandled model type: {0}'.format(self.tool) + + + internal, terminal = walk_tree(node.children[0], dom0, internal, terminal) + internal, terminal = walk_tree(node.children[1], dom1, internal, terminal) + + return internal, terminal + + domains = {f:[j for j in fvmap[f] if((fvmap[f][j][1]))] for f in fvmap} + internal, terminal = walk_tree(self.tree, domains, [], []) + + + dt = f"{self.n_nodes}\n{self.tree.id}\n" + dt += f"I {' '.join(dict.fromkeys([str(i) for i,_,_,_ in internal]))}\n" + dt +=f"T {' '.join([str(i) for i,_ in terminal ])}\n" + for i,c in terminal: + dt +=f"{i} T {c}\n" + for i,f, j, n in internal: + dt +=f"{i} {f} {j} {n}\n" + + map = "Categorical\n" + map += f"{len(fvmap)}" + for f in fvmap: + for v in fvmap[f]: + if (fvmap[f][v][1] == True): + map += f"\n{f} {v} ={fvmap[f][v][2]}" + if (fvmap[f][v][1] == False) and self.tool == "ITI": + map += f"\n{f} {v} !={fvmap[f][v][2]}" + + + if feat_names is not None: + features_names_mapping = [] + for i,fid in enumerate(feat_names): + feat=f'f{i}' + f = f'T:C,{fid}:{feat},'+",".join([f'{fvmap[feat][v][2]}:{v}' for v in fvmap[feat] if(fvmap[feat][v][1])]) + features_names_mapping.append(f) + + return dt, map, features_names_mapping + + def convert_dt(self, feat_names): + """ + save dtree in .dt format & generate dtree map from the tree + """ + def walk_tree(node, domains, internal, terminal): + """ + extract internal (non-term) & terminal nodes + """ + if not len(node.children): # leaf node + terminal.append((node.id, node.values)) + else: + # internal node + f = f"f{node.feature}" + left, right = [], [] + for j in domains[f]: + if self.intvs[f][j] <= node.threshold: + left.append(j) + else: + right.append(j) + + internal.extend([(node.id, f, l, node.children[0].id) for l in left]) + internal.extend([(node.id, f, r, node.children[1].id) for r in right]) + + dom0, dom1 = dict(), dict() + dom0.update(domains) + dom1.update(domains) + dom0[f] = left + dom1[f] = right + # + internal, terminal = walk_tree(node.children[0], dom0, internal, terminal) + internal, terminal = walk_tree(node.children[1], dom1, internal, terminal) + + return internal, terminal + + assert (self.tool == 'SKL') + domains = {f:[j for j in range(len(self.intvs[f]))] for f in self.intvs} + internal, terminal = walk_tree(self.tree, domains, [], []) + + dt = f"{self.n_nodes}\n{self.tree.id}\n" + dt += f"I {' '.join(dict.fromkeys([str(i) for i,_,_,_ in internal]))}\n" + dt += f"T {' '.join([str(i) for i,_ in terminal ])}" + for i,c in terminal: + dt += f"\n{i} T {c}" + for i,f, j, n in internal: + dt += f"\n{i} {f} {j} {n}" + + map = "Ordinal\n" + map += f"{len(self.intvs)}" + for f in self.intvs: + for j,t in enumerate(self.intvs[f][:-1]): + map += f"\n{f} {j} <={np.round(float(t),4)}" + map += f"\n{f} {j+1} >{np.round(float(t),4)}" + + + if feat_names is not None: + features_names_mapping = [] + for i,fid in enumerate(feat_names): + feat=f'f{i}' + if feat in self.intvs: + f = f'T:O,{fid}:{feat},' + f += ",".join([f'{t}:{j}' for j,t in enumerate(self.intvs[feat])]) + features_names_mapping.append(f) + + return dt, map, features_names_mapping + + + def build_tree(self, json_tree=None, feature_names=None): + + def extract_data(json_node, idx, depth=0, root=None, feature_names=None): + """ + Incremental Tree Inducer / DL8.5 + """ + if (root is None): + node = xgnode(idx) + else: + node = xgnode(idx, parent = root) + + if "feat" in json_node: + + if self.tool == "ITI": #f0, f1, ...,fn + node.feature = json_node["feat"][1:] + else: + node.feature = json_node["feat"] #json DL8.5 + if (feature_names is not None): + node.name = feature_names[node.feature] + + if self.tool == "ITI": + node.threshold = json_node[json_node["feat"]] + + node.left_node_id = idx + 1 + _, idx, d1 = extract_data(json_node['left'], idx+1, depth+1, node, feature_names) + node.right_node_id = idx + 1 + _, idx, d2 = extract_data(json_node['right'], idx+1, depth+1, node, feature_names) + depth = max(d1, d2) + + elif "value" in json_node: + node.values = json_node["value"] + + return node, idx, depth + + + def extract_iai(lnr, json_tree, feature_names = None): + """ + Interpretable AI tree + """ + + json_tree = json_tree['tree_'] + nodes = [] + depth = 0 + for i, json_node in enumerate(json_tree["nodes"]): + if json_node["parent"] == -2: + node = xgnode(json_node["id"]) + else: + root = nodes[json_node["parent"] - 1] + node = xgnode(json_node["id"], parent = root) + + assert (json_node["parent"] > 0) + assert (root.id == json_node["parent"]) + + if json_node["split_type"] == "LEAF": + #node.values = target[json_node["fit"]["class"] - 1] + ##assert json_node["fit"]["probs"][node.values] == 1.0 + node.values = lnr.get_classification_label(node.id) + depth = max(depth, lnr.get_depth(node.id)) + + assert (json_node["lower_child"] == -2 and json_node["upper_child"] == -2) + + elif json_node["split_type"] == "MIXED": + #node.feature = json_node["split_mixed"]["categoric_split"]["feature"] - 1 + #node.left_node_id = json_node["lower_child"] + #node.right_node_id = json_node["upper_child"] + + node.feature = lnr.get_split_feature(node.id) + node.left_node_id = lnr.get_lower_child(node.id) + node.right_node_id = lnr.get_upper_child(node.id) + node.split = lnr.get_split_categories(node.id) + + + assert (json_node["split_mixed"]["categoric_split"]["feature"] > 0) + assert (json_node["lower_child"] > 0) + assert (json_node["upper_child"] > 0) + + else: + assert False, 'Split feature is not \"categoric_split\"' + + nodes.append(node) + + return nodes[0], json_tree["node_count"], depth + + + def extract_skl(tree_, classes_, feature_names=None): + """ + scikit-learn tree + """ + + def get_CART_tree(tree_): + n_nodes = tree_.node_count + children_left = tree_.children_left + children_right = tree_.children_right + #feature = tree_.feature + #threshold = tree_.threshold + #values = tree_.value + node_depth = np.zeros(shape=n_nodes, dtype=np.int64) + + is_leaf = np.zeros(shape=n_nodes, dtype=bool) + stack = [(0, -1)] # seed is the root node id and its parent depth + while len(stack) > 0: + node_id, parent_depth = stack.pop() + node_depth[node_id] = parent_depth + 1 + + # If we have a test node + if (children_left[node_id] != children_right[node_id]): + stack.append((children_left[node_id], parent_depth + 1)) + stack.append((children_right[node_id], parent_depth + 1)) + else: + is_leaf[node_id] = True + + return children_left, children_right, is_leaf, node_depth + + children_left, children_right, is_leaf, node_depth = get_CART_tree(tree_) + feature = tree_.feature + threshold = tree_.threshold + values = tree_.value + m = tree_.node_count + assert (m > 0), "Empty tree" + ## + self.intvs = {f'f{feature[i]}':set([]) for i in range(tree_.node_count) if not is_leaf[i]} + for i in range(tree_.node_count): + if not is_leaf[i]: + self.intvs[f'f{feature[i]}'].add(threshold[i]) + self.intvs = {f: sorted(self.intvs[f])+[math.inf] for f in six.iterkeys(self.intvs)} + + def extract_data(idx, root = None, feature_names = None): + i = idx + assert (i < m), "Error index node" + if (root is None): + node = xgnode(i) + else: + node = xgnode(i, parent = root) + if is_leaf[i]: + node.values = classes_[np.argmax(values[i])] + else: + node.feature = feature[i] + if (feature_names): + node.name = feature_names[feature[i]] + node.threshold = threshold[i] + node.left_node_id = children_left[i] + node.right_node_id = children_right[i] + extract_data(node.left_node_id, node, feature_names) + extract_data(node.right_node_id, node, feature_names) + + return node + + root = extract_data(0, None, feature_names) + return root, tree_.node_count, tree_.max_depth + + + + root, node_count, maxdepth = None, None, None + + if(self.tool == 'SKL'): + if "feature_names_in_" in dir(self.model): + feature_names = self.model.feature_names_in_ + root, node_count, maxdepth = extract_skl(self.model.tree_, self.model.classes_, feature_names) + + if json_tree: + if self.tool == "IAI": + root, node_count, maxdepth = extract_iai(self.model, json_tree, feature_names) + else: + root,_,maxdepth = extract_data(json_tree, 1, 0, None, feature_names) + node_count = json.dumps(json_tree).count('feat') + json.dumps(json_tree).count('value') + + return root, node_count, maxdepth \ No newline at end of file diff --git a/pages/application/application.py b/pages/application/application.py index 16996a6aa067aa7bf5c30bd20b93793995d256c1..db54577aa48a31e87069afd59ebf34ed6a346ea0 100644 --- a/pages/application/application.py +++ b/pages/application/application.py @@ -1,5 +1,6 @@ from dash import dcc, html import dash_bootstrap_components as dbc +import dash_daq as daq from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent @@ -19,7 +20,13 @@ class Model(): self.ml_model = '' self.pretrained_model = '' - self.typ_data = '' + + self.add_info = False + self.model_info = '' + + self.enum=1 + self.xtype = ['AXp', 'CXp'] + self.solver="g3" self.instance = '' @@ -34,15 +41,35 @@ class Model(): self.component_class = self.dict_components[self.ml_model] self.component_class = globals()[self.component_class] - def update_pretrained_model(self, pretrained_model_update, typ_data): + def update_pretrained_model(self, pretrained_model_update): self.pretrained_model = pretrained_model_update - self.typ_data = typ_data - self.component = self.component_class(self.pretrained_model, self.typ_data) - def update_instance(self, instance, enum, xtype, solver="g3"): + def update_info_needed(self, add_info): + self.add_info = add_info + + def update_pretrained_model_layout(self): + self.component = self.component_class(self.pretrained_model) + + def update_pretrained_model_layout_with_info(self, model_info, model_info_filename): + self.model_info = model_info + self.component = self.component_class(self.pretrained_model, info=self.model_info, type_info=model_info_filename) + + def update_instance(self, instance): self.instance = instance - self.list_expls = self.component.update_with_explicability(self.instance, enum, xtype, solver) + self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) + + def update_enum(self, enum): + self.enum = enum + self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) + def update_xtype(self, xtype): + self.xtype = xtype + self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) + + def update_solver(self, solver): + self.solver = solver + self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) + def update_expl(self, expl): self.expl = expl self.component.draw_explanation(self.instance, expl) @@ -52,11 +79,18 @@ class View(): def __init__(self, model): self.model = model - self.ml_menu_models = dcc.Dropdown(self.model.ml_models, + self.ml_menu_models = html.Div([ + html.Br(), + html.Label("Choose the Machine Learning algorithm :"), + html.Br(), + dcc.Dropdown(self.model.ml_models, id='ml_model_choice', - className="sidebar-dropdown") - + className="dropdown")]) + self.pretrained_model_upload = html.Div([ + html.Hr(), + html.Label("Choose the pretrained model : "), + html.Br(), dcc.Upload( id='ml_pretrained_model_choice', children=html.Div([ @@ -67,7 +101,32 @@ class View(): ), html.Div(id='pretrained_model_filename')]) + self.add_model_info_choice = html.Div([ + html.Hr(), + html.Label("Do you wish to upload more info for your model ? : "), + html.Br(), + daq.BooleanSwitch(id='add_info_model_choice', on=False, color="#000000",)]) + + self.model_info = html.Div(id="choice_info_div", + hidden=True, + children=[ + html.Hr(), + html.Label("Choose the pretrained model dataset (csv) or feature definition file (txt): "), + html.Br(), + dcc.Upload( + id='model_info_choice', + children=html.Div([ + 'Drag and Drop or ', + html.A('Select File') + ]), + className="upload" + ), + html.Div(id='info_filename')]) + self.instance_upload = html.Div([ + html.Hr(), + html.Label("Choose the instance to explain : "), + html.Br(), dcc.Upload( id='ml_instance_choice', children=html.Div([ @@ -78,21 +137,7 @@ class View(): ), html.Div(id='instance_filename')]) - self.sidebar = dcc.Tabs(children=[ - dcc.Tab(label='Basic Parameters', children = [ - html.Br(), - html.Label("Choose the Machine Learning algorithm :"), - html.Br(), - self.ml_menu_models, - html.Hr(), - html.Label("Choose the pretrained model : "), - html.Br(), - self.pretrained_model_upload, - html.Hr(), - html.Label("Choose the instance to explain : "), - html.Br(), - self.instance_upload, - html.Hr(), + self.num_explanation = html.Div([ html.Label("Choose the number of explanations : "), html.Br(), dcc.Input( @@ -100,32 +145,48 @@ class View(): value=1, type="number", placeholder="How many explanations ?", - className="sidebar-dropdown"), - html.Hr(), + className="dropdown"), + html.Hr()]) + + self.type_explanation = html.Div([ html.Label("Choose the kind of explanation : "), html.Br(), dcc.Checklist( id="explanation_type", options={'AXp' : "Abductive Explanation", 'CXp': "Contrastive explanation"}, value = ['AXp', 'CXp'], - className="sidebar-dropdown", - inline=True)], className="sidebar"), - dcc.Tab(label='Advanced Parameters', children = [ - html.Hr(), - html.Label("Choose the SAT solver : "), + className="check-boxes", + inline=True), + html.Hr()]) + + self.solver = html.Div([ html.Label("Choose the SAT solver : "), html.Br(), - dcc.Dropdown(['g3', 'g4', 'lgl', 'mcb', 'mcm', 'mpl', 'm22', 'mc', 'mgh'], 'g3', id='solver_sat') - ], className="sidebar") - ]) + dcc.Dropdown(['g3', 'g4', 'lgl', 'mcb', 'mcm', 'mpl', 'm22', 'mc', 'mgh'], 'g3', id='solver_sat') ]) + + self.sidebar = dcc.Tabs(children=[ + dcc.Tab(label='Basic Parameters', children = [ + self.ml_menu_models, + self.pretrained_model_upload, + self.add_model_info_choice, + self.model_info, + self.instance_upload], className="sidebar"), + dcc.Tab(label='Advanced Parameters', children = [ + html.Br(), + self.num_explanation, + self.type_explanation, + self.solver + ], className="sidebar")]) - - self.expl_choice = dcc.Dropdown(self.model.list_expls, + self.expl_choice = html.Div([html.H5(id = "navigate_label", hidden=True, children="Navigate through the explanations and plot them on the tree : "), + html.Div(id='navigate_dropdown', hidden=True, + children = [dcc.Dropdown(self.model.list_expls, id='expl_choice', - className="dropdown") + className="dropdown")])]) - self.layout = dbc.Row([ dbc.Col([self.sidebar], width=3, class_name="sidebar"), + self.layout = dbc.Row([ dbc.Col([self.sidebar], + width=3, class_name="sidebar"), dbc.Col([dbc.Row(id = "graph", children=[]), - dbc.Row(html.Div([html.H5(id = "navigate_label", hidden=True, children="Navigate through the explanations and plot them on the tree : "), - html.Div(self.expl_choice, id='navigate_dropdown', hidden=True)]))], width=5, class_name="column_graph"), + dbc.Row(self.expl_choice)], + width=5, class_name="column_graph"), dbc.Col(html.Main(id = "explanation", children=[], hidden=True), width=4)]) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index db4cd1fddde22135bee76ca884fda1e72ab6f2c2..4f70ae69cdd7f78304ed576c0745589b7d79e127 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,6 @@ scipy>=1.2.1 dash_bootstrap_components dash_interactive_graphviz python-sat[pblib,aiger] -pygraphviz \ No newline at end of file +pygraphviz==1.9 +anytree==2.8.0 +dash_daq==0.5.0 \ No newline at end of file diff --git a/tests/adult/adult.pkl b/tests/adult/adult.pkl new file mode 100644 index 0000000000000000000000000000000000000000..b18b4d6e77cfc501320ebd5650158dd268530d1d Binary files /dev/null and b/tests/adult/adult.pkl differ diff --git a/tests/adult/adult_2.pkl b/tests/adult/adult_2.pkl new file mode 100644 index 0000000000000000000000000000000000000000..5221943e347e0a7b2b83e0ee33e5c603396536a3 Binary files /dev/null and b/tests/adult/adult_2.pkl differ diff --git a/tests/adult/adult_3.pkl b/tests/adult/adult_3.pkl new file mode 100644 index 0000000000000000000000000000000000000000..b17bc94372fe8d7b6db981e73bf34d6f37b1d07b Binary files /dev/null and b/tests/adult/adult_3.pkl differ diff --git a/tests/adult/adult_data_00000.inst b/tests/adult/adult_data_00000.inst new file mode 100644 index 0000000000000000000000000000000000000000..c72f4e9e0aec1da4c12fdaa3b585d6d96054a30a --- /dev/null +++ b/tests/adult/adult_data_00000.inst @@ -0,0 +1 @@ +f0=1,f1=9,f2=9,f3=7,f4=9,f5=5,f6=7,f7=0,f8=5,f9=3,f10=0,f11=15 \ No newline at end of file diff --git a/tests/iris/decision_tree_classifier_20170212.pkl b/tests/iris/decision_tree_classifier_20170212.pkl new file mode 100644 index 0000000000000000000000000000000000000000..006ccc214168eed2a37f7201ccca66648f1112f7 Binary files /dev/null and b/tests/iris/decision_tree_classifier_20170212.pkl differ diff --git a/tests/iris/iris.csv b/tests/iris/iris.csv new file mode 100644 index 0000000000000000000000000000000000000000..1b9d0294d6d589f667daf28efecbd0cb0276c352 --- /dev/null +++ b/tests/iris/iris.csv @@ -0,0 +1,151 @@ +"sepal.length","sepal.width","petal.length","petal.width","variety" +5.1,3.5,1.4,.2,"Setosa" +4.9,3,1.4,.2,"Setosa" +4.7,3.2,1.3,.2,"Setosa" +4.6,3.1,1.5,.2,"Setosa" +5,3.6,1.4,.2,"Setosa" +5.4,3.9,1.7,.4,"Setosa" +4.6,3.4,1.4,.3,"Setosa" +5,3.4,1.5,.2,"Setosa" +4.4,2.9,1.4,.2,"Setosa" +4.9,3.1,1.5,.1,"Setosa" +5.4,3.7,1.5,.2,"Setosa" +4.8,3.4,1.6,.2,"Setosa" +4.8,3,1.4,.1,"Setosa" +4.3,3,1.1,.1,"Setosa" +5.8,4,1.2,.2,"Setosa" +5.7,4.4,1.5,.4,"Setosa" +5.4,3.9,1.3,.4,"Setosa" +5.1,3.5,1.4,.3,"Setosa" +5.7,3.8,1.7,.3,"Setosa" +5.1,3.8,1.5,.3,"Setosa" +5.4,3.4,1.7,.2,"Setosa" +5.1,3.7,1.5,.4,"Setosa" +4.6,3.6,1,.2,"Setosa" +5.1,3.3,1.7,.5,"Setosa" +4.8,3.4,1.9,.2,"Setosa" +5,3,1.6,.2,"Setosa" +5,3.4,1.6,.4,"Setosa" +5.2,3.5,1.5,.2,"Setosa" +5.2,3.4,1.4,.2,"Setosa" +4.7,3.2,1.6,.2,"Setosa" +4.8,3.1,1.6,.2,"Setosa" +5.4,3.4,1.5,.4,"Setosa" +5.2,4.1,1.5,.1,"Setosa" +5.5,4.2,1.4,.2,"Setosa" +4.9,3.1,1.5,.2,"Setosa" +5,3.2,1.2,.2,"Setosa" +5.5,3.5,1.3,.2,"Setosa" +4.9,3.6,1.4,.1,"Setosa" +4.4,3,1.3,.2,"Setosa" +5.1,3.4,1.5,.2,"Setosa" +5,3.5,1.3,.3,"Setosa" +4.5,2.3,1.3,.3,"Setosa" +4.4,3.2,1.3,.2,"Setosa" +5,3.5,1.6,.6,"Setosa" +5.1,3.8,1.9,.4,"Setosa" +4.8,3,1.4,.3,"Setosa" +5.1,3.8,1.6,.2,"Setosa" +4.6,3.2,1.4,.2,"Setosa" +5.3,3.7,1.5,.2,"Setosa" +5,3.3,1.4,.2,"Setosa" +7,3.2,4.7,1.4,"Versicolor" +6.4,3.2,4.5,1.5,"Versicolor" +6.9,3.1,4.9,1.5,"Versicolor" +5.5,2.3,4,1.3,"Versicolor" +6.5,2.8,4.6,1.5,"Versicolor" +5.7,2.8,4.5,1.3,"Versicolor" +6.3,3.3,4.7,1.6,"Versicolor" +4.9,2.4,3.3,1,"Versicolor" +6.6,2.9,4.6,1.3,"Versicolor" +5.2,2.7,3.9,1.4,"Versicolor" +5,2,3.5,1,"Versicolor" +5.9,3,4.2,1.5,"Versicolor" +6,2.2,4,1,"Versicolor" +6.1,2.9,4.7,1.4,"Versicolor" +5.6,2.9,3.6,1.3,"Versicolor" +6.7,3.1,4.4,1.4,"Versicolor" +5.6,3,4.5,1.5,"Versicolor" +5.8,2.7,4.1,1,"Versicolor" +6.2,2.2,4.5,1.5,"Versicolor" +5.6,2.5,3.9,1.1,"Versicolor" +5.9,3.2,4.8,1.8,"Versicolor" +6.1,2.8,4,1.3,"Versicolor" +6.3,2.5,4.9,1.5,"Versicolor" +6.1,2.8,4.7,1.2,"Versicolor" +6.4,2.9,4.3,1.3,"Versicolor" +6.6,3,4.4,1.4,"Versicolor" +6.8,2.8,4.8,1.4,"Versicolor" +6.7,3,5,1.7,"Versicolor" +6,2.9,4.5,1.5,"Versicolor" +5.7,2.6,3.5,1,"Versicolor" +5.5,2.4,3.8,1.1,"Versicolor" +5.5,2.4,3.7,1,"Versicolor" +5.8,2.7,3.9,1.2,"Versicolor" +6,2.7,5.1,1.6,"Versicolor" +5.4,3,4.5,1.5,"Versicolor" +6,3.4,4.5,1.6,"Versicolor" +6.7,3.1,4.7,1.5,"Versicolor" +6.3,2.3,4.4,1.3,"Versicolor" +5.6,3,4.1,1.3,"Versicolor" +5.5,2.5,4,1.3,"Versicolor" +5.5,2.6,4.4,1.2,"Versicolor" +6.1,3,4.6,1.4,"Versicolor" +5.8,2.6,4,1.2,"Versicolor" +5,2.3,3.3,1,"Versicolor" +5.6,2.7,4.2,1.3,"Versicolor" +5.7,3,4.2,1.2,"Versicolor" +5.7,2.9,4.2,1.3,"Versicolor" +6.2,2.9,4.3,1.3,"Versicolor" +5.1,2.5,3,1.1,"Versicolor" +5.7,2.8,4.1,1.3,"Versicolor" +6.3,3.3,6,2.5,"Virginica" +5.8,2.7,5.1,1.9,"Virginica" +7.1,3,5.9,2.1,"Virginica" +6.3,2.9,5.6,1.8,"Virginica" +6.5,3,5.8,2.2,"Virginica" +7.6,3,6.6,2.1,"Virginica" +4.9,2.5,4.5,1.7,"Virginica" +7.3,2.9,6.3,1.8,"Virginica" +6.7,2.5,5.8,1.8,"Virginica" +7.2,3.6,6.1,2.5,"Virginica" +6.5,3.2,5.1,2,"Virginica" +6.4,2.7,5.3,1.9,"Virginica" +6.8,3,5.5,2.1,"Virginica" +5.7,2.5,5,2,"Virginica" +5.8,2.8,5.1,2.4,"Virginica" +6.4,3.2,5.3,2.3,"Virginica" +6.5,3,5.5,1.8,"Virginica" +7.7,3.8,6.7,2.2,"Virginica" +7.7,2.6,6.9,2.3,"Virginica" +6,2.2,5,1.5,"Virginica" +6.9,3.2,5.7,2.3,"Virginica" +5.6,2.8,4.9,2,"Virginica" +7.7,2.8,6.7,2,"Virginica" +6.3,2.7,4.9,1.8,"Virginica" +6.7,3.3,5.7,2.1,"Virginica" +7.2,3.2,6,1.8,"Virginica" +6.2,2.8,4.8,1.8,"Virginica" +6.1,3,4.9,1.8,"Virginica" +6.4,2.8,5.6,2.1,"Virginica" +7.2,3,5.8,1.6,"Virginica" +7.4,2.8,6.1,1.9,"Virginica" +7.9,3.8,6.4,2,"Virginica" +6.4,2.8,5.6,2.2,"Virginica" +6.3,2.8,5.1,1.5,"Virginica" +6.1,2.6,5.6,1.4,"Virginica" +7.7,3,6.1,2.3,"Virginica" +6.3,3.4,5.6,2.4,"Virginica" +6.4,3.1,5.5,1.8,"Virginica" +6,3,4.8,1.8,"Virginica" +6.9,3.1,5.4,2.1,"Virginica" +6.7,3.1,5.6,2.4,"Virginica" +6.9,3.1,5.1,2.3,"Virginica" +5.8,2.7,5.1,1.9,"Virginica" +6.8,3.2,5.9,2.3,"Virginica" +6.7,3.3,5.7,2.5,"Virginica" +6.7,3,5.2,2.3,"Virginica" +6.3,2.5,5,1.9,"Virginica" +6.5,3,5.2,2,"Virginica" +6.2,3.4,5.4,2.3,"Virginica" +5.9,3,5.1,1.8,"Virginica" \ No newline at end of file diff --git a/tests/iris/iris.pkl b/tests/iris/iris.pkl new file mode 100644 index 0000000000000000000000000000000000000000..cf2d521581ec57be284bb87b75316940c16fdcb7 Binary files /dev/null and b/tests/iris/iris.pkl differ diff --git a/tests/iris/iris.txt b/tests/iris/iris.txt new file mode 100644 index 0000000000000000000000000000000000000000..d356f5d1d011df170da9dacee5ac4b2d6bfc844b --- /dev/null +++ b/tests/iris/iris.txt @@ -0,0 +1,4 @@ +sepal.length,Categorical,7.6,6.8,7.1,4.9,4.4,6.2,6,7.3,5.9,7.4,5.2,5.6,4.8,6.5,5.5,4.6,6.6,6.4,7,4.5,7.2,5.1,5.8,5.3,6.9,6.1,6.7,4.7,7.7,6.3,5.7,7.9,5.4,4.3,5 +sepal.width,Categorical,4.2,4.4,3.1,2.4,2.9,2,3.8,4.1,4,3.2,2.7,3.3,2.2,2.5,2.3,3.6,3.5,3.9,2.8,2.6,3.7,3,3.4 +petal.length,Categorical,4.2,4.9,4.4,6,5.9,5.2,5.6,4.8,1,5.5,4.6,6.6,1.1,3.8,1.5,6.4,4.1,4,4.5,1.6,3.3,1.4,5.1,1.7,5.8,3.5,3.6,5.3,1.9,6.9,6.1,6.7,4.7,3.9,1.2,1.3,6.3,5.7,3.7,5.4,3,4.3,5 +petal.width,Categorical,2.4,.2,1,2,1.1,1.5,.6,.5,2.2,.3,1.6,1.4,2.5,1.7,2.3,1.8,2.1,1.9,1.2,1.3,.1,.4 \ No newline at end of file diff --git a/tests/iris/iris01.json b/tests/iris/iris01.json new file mode 100644 index 0000000000000000000000000000000000000000..59c0840b5597ccb4347c3c2041707f0ae1d11aa0 --- /dev/null +++ b/tests/iris/iris01.json @@ -0,0 +1,4 @@ +{"sepal.length":4.9, +"sepal.width":3, +"petal.length":1.4, +"petal.width":0.2} diff --git a/tests/iris/iris2.pkl b/tests/iris/iris2.pkl new file mode 100644 index 0000000000000000000000000000000000000000..cf2d521581ec57be284bb87b75316940c16fdcb7 Binary files /dev/null and b/tests/iris/iris2.pkl differ diff --git a/tests/iris/iris_00000.txt b/tests/iris/iris_00000.txt new file mode 100644 index 0000000000000000000000000000000000000000..108004d0deba35123ce4bdbcb9fa37930d6164ba --- /dev/null +++ b/tests/iris/iris_00000.txt @@ -0,0 +1 @@ +sepal.length=4.3,sepal.width=2.0,petal.length=1.0,petal.width=0.1 \ No newline at end of file diff --git a/tests/zoo/inst/zoo_00.inst b/tests/zoo/inst/zoo_00.inst new file mode 100644 index 0000000000000000000000000000000000000000..8d6abc5618927fb2ddf49400932809448ebef26c --- /dev/null +++ b/tests/zoo/inst/zoo_00.inst @@ -0,0 +1 @@ +f0=1,f1=0,f1=0,f1=1,f1=0,f1=0,f1=1,f1=1,f1=1,f1=1,f1=0,f1=0,f1=4,f1=0,f1=0,f1=1 \ No newline at end of file diff --git a/tests/zoo/inst/zoo_01.inst b/tests/zoo/inst/zoo_01.inst new file mode 100644 index 0000000000000000000000000000000000000000..4a6df5103d348fbf79e75654cb20f9594d47e842 --- /dev/null +++ b/tests/zoo/inst/zoo_01.inst @@ -0,0 +1 @@ +1,0,0,1,0,0,0,1,1,1,0,0,4,1,0,1 \ No newline at end of file diff --git a/tests/zoo/inst/zoo_02.inst b/tests/zoo/inst/zoo_02.inst new file mode 100644 index 0000000000000000000000000000000000000000..d72c0ae09ad51c4f6ec073b66e773beb18515be4 --- /dev/null +++ b/tests/zoo/inst/zoo_02.inst @@ -0,0 +1 @@ +0,0,1,0,0,1,1,1,1,0,0,1,0,1,0,0 \ No newline at end of file diff --git a/tests/zoo/inst/zoo_11.inst b/tests/zoo/inst/zoo_11.inst new file mode 100644 index 0000000000000000000000000000000000000000..0fa9d7c8a871929cbe518a7be16966fc2f7b9a6d --- /dev/null +++ b/tests/zoo/inst/zoo_11.inst @@ -0,0 +1 @@ +0,1,1,0,1,0,0,0,1,1,0,0,2,1,1,0 \ No newline at end of file diff --git a/tests/zoo/zoo.csv b/tests/zoo/zoo.csv new file mode 100644 index 0000000000000000000000000000000000000000..7eb9774cdc5f452baac8ed8d4ea0efe83a466386 --- /dev/null +++ b/tests/zoo/zoo.csv @@ -0,0 +1,102 @@ +hair,feathers,eggs,milk,airborne,aquatic,predator,toothed,backbone,breathes,venomous,fins,legs,tail,domestic,catsize,class_type +1,0,0,1,0,0,1,1,1,1,0,0,4,0,0,1,mammal +1,0,0,1,0,0,0,1,1,1,0,0,4,1,0,1,mammal +0,0,1,0,0,1,1,1,1,0,0,1,0,1,0,0,fish +1,0,0,1,0,0,1,1,1,1,0,0,4,0,0,1,mammal +1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,1,mammal +1,0,0,1,0,0,0,1,1,1,0,0,4,1,0,1,mammal +1,0,0,1,0,0,0,1,1,1,0,0,4,1,1,1,mammal +0,0,1,0,0,1,0,1,1,0,0,1,0,1,1,0,fish +0,0,1,0,0,1,1,1,1,0,0,1,0,1,0,0,fish +1,0,0,1,0,0,0,1,1,1,0,0,4,0,1,0,mammal +1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,1,mammal +0,1,1,0,1,0,0,0,1,1,0,0,2,1,1,0,bird +0,0,1,0,0,1,1,1,1,0,0,1,0,1,0,0,fish +0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,invertebrate +0,0,1,0,0,1,1,0,0,0,0,0,4,0,0,0,invertebrate +0,0,1,0,0,1,1,0,0,0,0,0,6,0,0,0,invertebrate +0,1,1,0,1,0,1,0,1,1,0,0,2,1,0,0,bird +1,0,0,1,0,0,0,1,1,1,0,0,4,1,0,1,mammal +0,0,1,0,0,1,1,1,1,0,0,1,0,1,0,1,fish +0,0,0,1,0,1,1,1,1,1,0,1,0,1,0,1,mammal +0,1,1,0,1,0,0,0,1,1,0,0,2,1,1,0,bird +0,1,1,0,1,1,0,0,1,1,0,0,2,1,0,0,bird +1,0,0,1,0,0,0,1,1,1,0,0,4,1,0,1,mammal +0,1,1,0,1,0,0,0,1,1,0,0,2,1,0,1,bird +0,0,1,0,0,0,0,0,0,1,0,0,6,0,0,0,bug +0,0,1,0,0,1,1,1,1,1,0,0,4,0,0,0,amphibian +0,0,1,0,0,1,1,1,1,1,1,0,4,0,0,0,amphibian +1,0,0,1,1,0,0,1,1,1,0,0,2,1,0,0,mammal +1,0,0,1,0,0,0,1,1,1,0,0,4,1,0,1,mammal +1,0,0,1,0,0,1,1,1,1,0,0,2,0,1,1,mammal +0,0,1,0,1,0,0,0,0,1,0,0,6,0,0,0,bug +1,0,0,1,0,0,0,1,1,1,0,0,4,1,1,1,mammal +1,0,0,1,0,0,0,1,1,1,0,0,2,0,0,1,mammal +0,1,1,0,1,1,1,0,1,1,0,0,2,1,0,0,bird +0,0,1,0,0,1,0,1,1,0,0,1,0,1,0,0,fish +1,0,0,1,0,0,0,1,1,1,0,0,4,1,1,0,mammal +1,0,0,1,0,0,0,1,1,1,0,0,4,1,0,0,mammal +0,1,1,0,1,0,1,0,1,1,0,0,2,1,0,0,bird +0,0,1,0,0,1,1,1,1,0,0,1,0,1,0,0,fish +1,0,1,0,1,0,0,0,0,1,1,0,6,0,1,0,bug +1,0,1,0,1,0,0,0,0,1,0,0,6,0,0,0,bug +0,1,1,0,0,0,1,0,1,1,0,0,2,1,0,0,bird +0,0,1,0,1,0,1,0,0,1,0,0,6,0,0,0,bug +0,1,1,0,1,0,0,0,1,1,0,0,2,1,0,0,bird +1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,1,mammal +1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,1,mammal +0,0,1,0,0,1,1,0,0,0,0,0,6,0,0,0,invertebrate +1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,1,mammal +1,0,0,1,0,1,1,1,1,1,0,0,4,1,0,1,mammal +1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,0,mammal +1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,1,mammal +1,0,1,0,1,0,0,0,0,1,0,0,6,0,0,0,bug +0,0,1,0,0,1,1,1,1,1,0,0,4,1,0,0,amphibian +0,0,1,0,0,1,1,0,0,0,0,0,8,0,0,1,invertebrate +1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,0,mammal +1,0,0,1,0,0,0,1,1,1,0,0,4,1,0,1,mammal +0,1,1,0,0,0,0,0,1,1,0,0,2,1,0,1,bird +0,1,1,0,1,0,0,0,1,1,0,0,2,1,1,0,bird +0,1,1,0,0,1,1,0,1,1,0,0,2,1,0,1,bird +0,1,1,0,1,0,0,0,1,1,0,0,2,1,0,0,bird +0,0,1,0,0,1,1,1,1,0,0,1,0,1,0,1,fish +0,0,1,0,0,1,1,1,1,0,0,1,0,1,0,0,fish +0,0,1,0,0,0,1,1,1,1,1,0,0,1,0,0,reptile +1,0,1,1,0,1,1,0,1,1,0,0,4,1,0,1,mammal +1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,1,mammal +1,0,0,1,0,0,0,1,1,1,0,0,4,1,1,1,mammal +0,0,0,1,0,1,1,1,1,1,0,1,0,1,0,1,mammal +1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,1,mammal +1,0,0,1,0,0,1,1,1,1,0,0,4,1,1,1,mammal +1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,1,mammal +1,0,0,1,0,0,0,1,1,1,0,0,4,1,1,1,mammal +0,1,1,0,0,0,1,0,1,1,0,0,2,1,0,1,bird +0,0,0,0,0,0,1,0,0,1,1,0,8,1,0,0,invertebrate +0,0,1,0,0,1,0,1,1,0,0,1,0,1,0,0,fish +1,0,0,1,0,1,1,1,1,1,0,1,0,0,0,1,mammal +1,0,0,1,0,1,1,1,1,1,0,1,2,1,0,1,mammal +0,0,0,0,0,1,1,1,1,0,1,0,0,1,0,0,reptile +0,0,1,0,0,1,1,0,0,0,1,0,0,0,0,0,invertebrate +0,1,1,0,1,1,1,0,1,1,0,0,2,1,0,0,bird +0,1,1,0,1,1,1,0,1,1,0,0,2,1,0,0,bird +0,0,1,0,0,0,1,1,1,1,0,0,0,1,0,0,reptile +0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,invertebrate +0,0,1,0,0,1,0,1,1,0,0,1,0,1,0,0,fish +0,1,1,0,1,0,0,0,1,1,0,0,2,1,0,0,bird +1,0,0,1,0,0,0,1,1,1,0,0,2,1,0,0,mammal +0,0,1,0,0,1,1,0,0,0,0,0,5,0,0,0,invertebrate +0,0,1,0,0,1,1,1,1,0,1,1,0,1,0,1,fish +0,1,1,0,1,1,0,0,1,1,0,0,2,1,0,1,bird +0,0,1,0,0,0,0,0,0,1,0,0,6,0,0,0,bug +0,0,1,0,0,1,0,1,1,1,0,0,4,0,0,0,amphibian +0,0,1,0,0,0,0,0,1,1,0,0,4,1,0,1,reptile +0,0,1,0,0,0,1,1,1,1,0,0,4,1,0,0,reptile +0,0,1,0,0,1,1,1,1,0,0,1,0,1,0,1,fish +1,0,0,1,1,0,0,1,1,1,0,0,2,1,0,0,mammal +1,0,0,1,0,0,0,1,1,1,0,0,4,1,0,0,mammal +0,1,1,0,1,0,1,0,1,1,0,0,2,1,0,1,bird +1,0,0,1,0,0,0,1,1,1,0,0,2,1,0,1,mammal +1,0,1,0,1,0,0,0,0,1,1,0,6,0,0,0,bug +1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,1,mammal +0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,invertebrate +0,1,1,0,1,0,0,0,1,1,0,0,2,1,0,0,bird \ No newline at end of file diff --git a/tests/zoo/zoo.dt b/tests/zoo/zoo.dt new file mode 100644 index 0000000000000000000000000000000000000000..a1ee6427913de0385b71077c93d9889e0479c78d --- /dev/null +++ b/tests/zoo/zoo.dt @@ -0,0 +1,85 @@ +41 +1 +I 1 2 3 6 9 10 12 14 17 18 20 21 22 24 28 30 33 35 37 38 +T 4 5 7 8 11 13 15 16 19 23 25 26 27 29 31 32 34 36 39 40 41 +4 T bug +5 T mammal +7 T bird +8 T bug +11 T mammal +13 T fish +15 T bird +16 T reptile +19 T mammal +23 T reptile +25 T reptile +26 T bird +27 T fish +29 T amphibian +31 T amphibian +32 T reptile +34 T invertebrate +36 T invertebrate +39 T invertebrate +40 T bug +41 T invertebrate +1 f4 0 2 +1 f4 1 9 +2 f0 0 3 +2 f0 1 6 +3 f10 0 4 +3 f10 1 5 +6 f12 1 7 +6 f12 5 7 +6 f12 9 7 +6 f12 11 7 +6 f12 3 8 +6 f12 7 8 +9 f15 0 10 +9 f15 1 17 +10 f3 0 11 +10 f3 1 12 +12 f11 0 13 +12 f11 1 14 +14 f12 1 15 +14 f12 9 15 +14 f12 3 16 +14 f12 5 16 +14 f12 7 16 +14 f12 11 16 +17 f8 0 18 +17 f8 1 33 +18 f0 0 19 +18 f0 1 20 +20 f12 1 21 +20 f12 9 21 +20 f12 3 28 +20 f12 5 28 +20 f12 7 28 +20 f12 11 28 +21 f6 0 22 +21 f6 1 27 +22 f10 0 23 +22 f10 1 24 +24 f12 9 25 +24 f12 1 26 +24 f12 3 26 +24 f12 5 26 +24 f12 7 26 +24 f12 11 26 +28 f10 0 29 +28 f10 1 30 +30 f5 0 31 +30 f5 1 32 +33 f5 0 34 +33 f5 1 35 +35 f13 0 36 +35 f13 1 37 +37 f9 0 38 +37 f9 1 41 +38 f12 1 39 +38 f12 5 39 +38 f12 9 39 +38 f12 11 39 +38 f12 3 40 +38 f12 7 40 diff --git a/tests/zoo/zoo.json b/tests/zoo/zoo.json new file mode 100644 index 0000000000000000000000000000000000000000..506568b6a24d9f5066f1673d69e24a45d1109299 --- /dev/null +++ b/tests/zoo/zoo.json @@ -0,0 +1,18 @@ +{ +"hair":1, +"feathers":0, +"eggs":0, +"milk":1, +"airborne":0, +"aquatic":0, +"predator":0, +"toothed":1, +"backbone":1, +"breathes":1, +"venomous":0, +"fins":0, +"legs":6, +"tail":1, +"domestic":0, +"catsize":1 +} diff --git a/tests/zoo/zoo.map b/tests/zoo/zoo.map new file mode 100644 index 0000000000000000000000000000000000000000..89838b19a6ef9d3425d66ce4c6ec27f5c3981fdd --- /dev/null +++ b/tests/zoo/zoo.map @@ -0,0 +1,38 @@ +Categorical +16 +f0 0 =1 +f0 1 =0 +f1 0 =1 +f1 1 =0 +f2 0 =1 +f2 1 =0 +f3 0 =1 +f3 1 =0 +f4 0 =1 +f4 1 =0 +f5 0 =1 +f5 1 =0 +f6 0 =1 +f6 1 =0 +f7 0 =1 +f7 1 =0 +f8 0 =1 +f8 1 =0 +f9 0 =1 +f9 1 =0 +f10 0 =1 +f10 1 =0 +f11 0 =1 +f11 1 =0 +f12 1 =2 +f12 3 =6 +f12 5 =5 +f12 7 =8 +f12 9 =0 +f12 11 =4 +f13 0 =1 +f13 1 =0 +f14 0 =1 +f14 1 =0 +f15 0 =1 +f15 1 =0 diff --git a/tests/zoo/zoo.pkl b/tests/zoo/zoo.pkl new file mode 100644 index 0000000000000000000000000000000000000000..edfd6540dd861d1260292b0bce1efe6d6a809461 Binary files /dev/null and b/tests/zoo/zoo.pkl differ diff --git a/utils.py b/utils.py index a3afff7ac45293e228d2a82e510de33d05ec77b5..42b30c99af3af1b8463b3aa8e29f1834e974e0b3 100644 --- a/utils.py +++ b/utils.py @@ -1,6 +1,8 @@ import base64 import io import pickle +import joblib +import json import numpy as np from dash import html @@ -11,28 +13,47 @@ def parse_contents_graph(contents, filename): decoded = base64.b64decode(content_string) try: if '.pkl' in filename: - data = pickle.load(io.BytesIO(decoded)) - typ = 'pkl' + data = joblib.load(io.BytesIO(decoded)) except Exception as e: print(e) return html.Div([ 'There was an error processing this file.' ]) - return data, typ + return data + +def parse_contents_data(contents, filename): + content_type, content_string = contents.split(',') + decoded = base64.b64decode(content_string) + try: + if '.csv' in filename: + data = decoded.decode('utf-8') + if '.txt' in filename: + data = decoded.decode('utf-8') + except Exception as e: + print(e) + return html.Div([ + 'There was an error processing this file.' + ]) + + return data def parse_contents_instance(contents, filename): content_type, content_string = contents.split(',') decoded = base64.b64decode(content_string) try: - if 'csv' in filename: + if '.csv' in filename: data = decoded.decode('utf-8') - elif 'txt' in filename: + data = str(data).strip().split(',') + data = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in data])) + elif '.txt' in filename: data = decoded.decode('utf-8') - else : + data = str(data).strip().split(',') + data = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in data])) + elif '.json' in filename: data = decoded.decode('utf-8') - data = str(data).strip().split(',') - data = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in data])) + data = json.loads(data) + data = list(tuple(data.items())) except Exception as e: print(e) return html.Div([