diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4a1ff58c3a633f51533d402e7316cd7819fb711e --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +.idea +__pycache__ +pages/application/DecisionTree/utils/__pycache__ +pages/application/DecisionTree/__pycache__ +pages/application/__pycache__ +decision_tree_classifier_20170212.pkl \ No newline at end of file diff --git a/app.py b/app.py index 41455d2f594c8664b14f6dc32b539b96a2f3cac3..e605e68bf0922f4f6ca56198ba817b533d2ddd50 100644 --- a/app.py +++ b/app.py @@ -1,19 +1,15 @@ # Run this app with `python app.py` and # visit http://127.0.0.1:8050/ in your web browser. +import json + import dash -import json -from dash import dcc -from dash import html -from dash import dcc, html, Input, Output, State import dash_bootstrap_components as dbc +import pandas as pd +from dash import Input, Output, State, dcc, html from dash.exceptions import PreventUpdate from pages.application.layout_application import Model, View -from utils import parse_contents_instance, parse_contents_tree, extract_data - -import base64 -import io -import pandas as pd +from utils import extract_data, parse_contents_instance, parse_contents_tree ''' Loading data @@ -44,7 +40,7 @@ app.layout = html.Div([ Callback for the app ''' @app.callback( - Output('tree_filename', 'children'), + Output('dataset_filename', 'children'), Output('instance_filename', 'children'), Output('graph', 'children'), Output('explanation', 'children'), @@ -64,14 +60,14 @@ def update_ml_type(value_ml_model, dataset_contents, instance_contents, enum, xt ihm_id = ctx.triggered[0]['prop_id'].split('.')[0] if ihm_id == 'ml_model_choice' : model.update_ml_model(value_ml_model) - return dataset_filename, instance_filename, "", "" + return "", "", "", "" elif ihm_id == 'ml_dataset_choice': if value_ml_model == None : raise PreventUpdate - tree = parse_contents_tree(dataset_contents, dataset_filename) - model.update_dataset(tree) - return dataset_filename, instance_filename, model.component.network, "" + tree, typ = parse_contents_tree(dataset_contents, dataset_filename) + model.update_dataset(tree, typ) + return dataset_filename, "", model.component.network, "" elif ihm_id == 'ml_instance_choice' : if value_ml_model == None or dataset_contents == None or enum == None or xtype==None: @@ -107,4 +103,4 @@ def update_ml_type(value_ml_model, dataset_contents, instance_contents, enum, xt Launching app ''' if __name__ == '__main__': - app.run_server(debug=True) \ No newline at end of file + app.run_server(debug=True) diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py index 15f1d2c30f88a199b49f8e79146f82a7f2df4551..6ecb24cc60ae38cd71145b53532efda8f1f0afdb 100644 --- a/pages/application/DecisionTree/DecisionTreeComponent.py +++ b/pages/application/DecisionTree/DecisionTreeComponent.py @@ -8,9 +8,12 @@ from os import path class DecisionTreeComponent(): - def __init__(self, tree): + def __init__(self, tree, typ_data): - self.dt = DecisionTree(from_string = tree) + if typ_data == "dt" : + self.dt = DecisionTree(from_dt = tree) + elif typ_data == "pkl" : + self.dt = DecisionTree(from_pickle = tree) dot_source = visualize(self.dt) diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py index db322142bb6ffac20ecf660b7c5ac5290d16072c..ad48a393c16c041dc26fe059eb7b4ae5101280b3 100644 --- a/pages/application/DecisionTree/utils/dtree.py +++ b/pages/application/DecisionTree/utils/dtree.py @@ -22,6 +22,8 @@ try: # for Python2 from cStringIO import StringIO except ImportError: # for Python3 from io import StringIO +from sklearn.tree import _tree +import numpy as np # #============================================================================== @@ -30,13 +32,16 @@ class Node(): Node class. """ - def __init__(self, feat='', vals=[]): + def __init__(self, feat='', vals=[], threshold=None): """ Constructor. """ self.feat = feat - self.vals = vals + if threshold is not None : + self.threshold = threshold + else : + self.vals = vals # @@ -46,14 +51,13 @@ class DecisionTree(): Simple decision tree class. """ - def __init__(self, from_file=None, from_fp=None, from_string=None, - mapfile=None, medges=False, verbose=0): + def __init__(self, from_dt=None, from_pickle=None, + mapfile=None, verbose=0): """ Constructor. """ self.verbose = verbose - self.medges = medges and mapfile != None self.nof_nodes = 0 self.nof_terms = 0 @@ -70,12 +74,10 @@ class DecisionTree(): OHEMap = collections.namedtuple('OHEMap', ['dir', 'opp']) self.ohmap = OHEMap(dir={}, opp={}) - if from_file: - self.from_file(from_file) - elif from_string: - self.from_string(from_string) - elif from_fp: - self.from_fp(from_fp) + if from_dt: + self.from_dt(from_dt) + elif from_pickle: + self.from_pickle_file(from_pickle) if mapfile: self.parse_mapping(mapfile) @@ -84,15 +86,76 @@ class DecisionTree(): for v in self.fdoms[f]: self.fvmap[tuple([f, v])] = '{0}={1}'.format(f, v) - if self.medges: - self.convert_to_multiedges() + #problem de feature names et problem de vals dans node + def from_pickle_file(self, tree): + tree_ = tree.tree_ + try: + feature_names = tree.feature_names_in_ + except: + print("You did not dump the model with the features names") + feature_names = [str(i) for i in range(tree.n_features_in_)] - def from_fp(self, fp): + class_names = tree.classes_ + + self.nodes = collections.defaultdict(lambda: Node(feat='', vals={})) + self.terms={} + self.nof_nodes = tree_.node_count + self.nof_terms = 0 + self.root_node = 0 + + feature_name = [ + feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" + for i in tree_.feature] + + def recurse(feats, fdoms, node): + if tree_.feature[node] != _tree.TREE_UNDEFINED: + name = feature_name[node] + val = tree_.threshold[node] + + #faire une boucle for des vals ? + self.nodes[int(node)].feat = name + self.nodes[int(node)].vals[int(np.round(val,4))] = int(tree_.children_left[node]) + + self.nodes[int(node)].feat = name + self.nodes[int(node)].vals[int(4854)] = int(tree_.children_right[node]) + + feats.add(name) + fdoms[name].add(int(np.round(val,4))) + feats, fdoms = recurse(feats, fdoms, tree_.children_left[node]) + fdoms[name].add(4854) + feats, fdoms = recurse(feats, fdoms, tree_.children_right[node]) + + else: + self.terms[node] = class_names[np.argmax(tree_.value[node])] + print("leaf {}".format(tree_.value[node])) + + return feats, fdoms + + self.feats, self.fdoms = recurse(set([]), collections.defaultdict(lambda: set([])), self.root_node) + + for parent in self.nodes: + conns = collections.defaultdict(lambda: set([])) + for val, child in self.nodes[parent].vals.items(): + conns[child].add(val) + self.nodes[parent].vals = {frozenset(val): child for child, val in conns.items()} + + self.feats = sorted(self.feats) + self.feids = {f: i for i, f in enumerate(self.feats)} + self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms} + self.nof_terms = len(self.terms) + self.nof_feats = len(self.feats) + + self.paths = collections.defaultdict(lambda: []) + self.extract_paths(root=self.root_node, prefix=[]) + + def from_dt(self, data): """ Get the tree from a file pointer. """ - lines = fp.readlines() + contents = StringIO(data) + + lines = contents.readlines() # filtering out comment lines (those that start with '#') lines = list(filter(lambda l: not l.startswith('#'), lines)) @@ -149,13 +212,6 @@ class DecisionTree(): self.paths = collections.defaultdict(lambda: []) self.extract_paths(root=self.root_node, prefix=[]) - def from_string(self, string): - """ - Get a DT from a string. - """ - - self.from_fp(StringIO(string)) - def extract_paths(self, root, prefix): """ Traverse the tree and extract explicit paths. diff --git a/pages/application/layout_application.py b/pages/application/layout_application.py index 59efb81b1a4f12606d9dabfd806adb5056ceba51..78c5c019436a0e3ac2a5e74ea59434b872b586f6 100644 --- a/pages/application/layout_application.py +++ b/pages/application/layout_application.py @@ -20,6 +20,7 @@ class Model(): self.ml_model = '' self.dataset = '' + self.typ_data = '' self.instance = '' @@ -29,11 +30,12 @@ class Model(): def update_ml_model(self, ml_model_update): self.ml_model = ml_model_update self.component_class = self.dict_components[self.ml_model] + self.component_class = globals()[self.component_class] - def update_dataset(self, dataset_update): + def update_dataset(self, dataset_update, typ_data): self.dataset = dataset_update - self.component_class = globals()[self.component_class] - self.component = self.component_class(self.dataset) + self.typ_data = typ_data + self.component = self.component_class(self.dataset, self.typ_data) def update_instance(self, instance, enum, xtype, solver="g3"): self.instance = instance @@ -64,7 +66,7 @@ class View(): 'margin': '10px' } ), - html.Div(id='tree_filename')]) + html.Div(id='dataset_filename')]) self.instance_upload = html.Div([ dcc.Upload( diff --git a/utils.py b/utils.py index dcdf87631f5cb067c89fb201da501c5c6eaad55c..4b194caae78a7169bce243b206beb8ecee45789a 100644 --- a/utils.py +++ b/utils.py @@ -1,25 +1,28 @@ -import dash -import json -from dash import html -from dash import dcc -import dash_bootstrap_components as dbc - import base64 import io +import pickle import pandas as pd +import sklearn +from dash import html + def parse_contents_tree(contents, filename): content_type, content_string = contents.split(',') decoded = base64.b64decode(content_string) - try: - data = decoded.decode('utf-8') + try: + if '.dt' in filename: + data = decoded.decode('utf-8') + typ = 'dt' + elif '.pkl' in filename: + data = pickle.load(io.BytesIO(decoded)) + typ = 'pkl' except Exception as e: print(e) return html.Div([ 'There was an error processing this file.' ]) - return data + return data, typ def parse_contents_instance(contents, filename): content_type, content_string = contents.split(',') @@ -49,4 +52,4 @@ def extract_data(data): ml_type = data[i]['ml_type'] dict_components[ml_type] = data[i]['component'] - return names_models, dict_components \ No newline at end of file + return names_models, dict_components