From 8b31c67631aa8ce02d6b2b1b54bab5549f069f33 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Fri, 4 Mar 2022 11:07:14 +0100
Subject: [PATCH] Integrate pickle-only tree loading (working); models with
 categorical variables still need work

---
 .gitignore                                    |   6 +-
 .../DecisionTree/DecisionTreeComponent.py     |   2 +
 pages/application/DecisionTree/utils/dtree.py | 278 +++++-------------
 pages/application/DecisionTree/utils/dtviz.py |  73 ++---
 utils.py                                      |   5 +-
 5 files changed, 126 insertions(+), 238 deletions(-)

diff --git a/.gitignore b/.gitignore
index 4a1ff58..ac7f262 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,8 @@ __pycache__
 pages/application/DecisionTree/utils/__pycache__
 pages/application/DecisionTree/__pycache__
 pages/application/__pycache__
-decision_tree_classifier_20170212.pkl
\ No newline at end of file
+decision_tree_classifier_20170212.pkl
+push_command
+adult.pkl
+adult_data_00000.inst
+iris_00000.txt
\ No newline at end of file
diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py
index 6ecb24c..93716c0 100644
--- a/pages/application/DecisionTree/DecisionTreeComponent.py
+++ b/pages/application/DecisionTree/DecisionTreeComponent.py
@@ -5,6 +5,7 @@ import dash_interactive_graphviz
 
 import os.path
 from os import path
+import numpy as np 
 
 class DecisionTreeComponent():
 
@@ -30,6 +31,7 @@ class DecisionTreeComponent():
 
     def update_with_explicability(self, instance, enum, xtype, solver) :
         instance = str(instance).strip().split(',')
+        instance = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in instance]))
 
         dot_source = visualize_instance(self.dt, instance)
         self.network = dash_interactive_graphviz.DashInteractiveGraphviz(
diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py
index c6c6bc1..2a8b75f 100644
--- a/pages/application/DecisionTree/utils/dtree.py
+++ b/pages/application/DecisionTree/utils/dtree.py
@@ -17,6 +17,8 @@ from pysat.card import *
 from pysat.examples.hitman import Hitman
 from pysat.formula import CNF, IDPool
 from pysat.solvers import Solver
+import sklearn
+from torch import threshold
 
 try:  # for Python2
     from cStringIO import StringIO
@@ -32,7 +34,7 @@ class Node():
         Node class.
     """
 
-    def __init__(self, feat='', vals=[], threshold=None):
+    def __init__(self, feat='', vals=None, threshold=None, children_left= None, children_right=None):
         """
             Constructor.
         """
@@ -40,8 +42,10 @@ class Node():
         self.feat = feat
         if threshold is not None :
             self.threshold = threshold
+            self.children_left = 0
+            self.children_right = 0
         else : 
-            self.vals = vals
+            self.vals = {}
 
 
 #
@@ -51,12 +55,13 @@ class DecisionTree():
         Simple decision tree class.
     """
 
-    def __init__(self, from_dt=None, from_pickle=None, verbose=0):
+    def __init__(self, from_pickle=None, verbose=0):
         """
             Constructor.
         """
 
         self.verbose = verbose
+        self.typ=""
 
         self.nof_nodes = 0
         self.nof_terms = 0
@@ -66,26 +71,17 @@ class DecisionTree():
         self.paths = {}
         self.feats = []
         self.feids = {}
-        self.fdoms = {}
-        self.fvmap = {}
 
-        # OHE mapping
-        OHEMap = collections.namedtuple('OHEMap', ['dir', 'opp'])
-        self.ohmap = OHEMap(dir={}, opp={})
-
-        if from_dt:
-            self.from_dt(from_dt)     
-        elif from_pickle:
+        if from_pickle:
+            self.typ="pkl"
+            self.tree_ = ''
             self.from_pickle_file(from_pickle)
 
-        for f in self.feats:
-            for v in self.fdoms[f]:
-                self.fvmap[tuple([f, v])] = '{0}={1}'.format(f, v)
-
     #problem de feature names et problem de vals dans node
     def from_pickle_file(self, tree):
         #help(_tree.Tree)
-        tree_ = tree.tree_
+        self.tree_ = tree.tree_
+        print(sklearn.tree.export_text(tree))
         try:
             feature_names = tree.feature_names_in_
         except:
@@ -93,51 +89,37 @@ class DecisionTree():
             feature_names = [str(i) for i in range(tree.n_features_in_)] 
 
         class_names = tree.classes_
-
-        self.nodes = collections.defaultdict(lambda: Node(feat='', vals={}))
+        self.nodes = collections.defaultdict(lambda: Node(feat='', threshold=int(0), children_left=int(0), children_right=int(0)))
         self.terms={}
-        self.nof_nodes = tree_.node_count
-        self.nof_terms = 0
+        self.nof_nodes = self.tree_.node_count
         self.root_node = 0
+        self.feats = feature_names
 
         feature_name = [
             feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
-            for i in tree_.feature]
+            for i in self.tree_.feature]
 
-        def recurse(feats, fdoms, node):
-            if tree_.feature[node] != _tree.TREE_UNDEFINED:
+        def recurse(node):
+            if self.tree_.feature[node] != _tree.TREE_UNDEFINED:
                 name = feature_name[node]
-                val = tree_.threshold[node]
+                val = self.tree_.threshold[node]
 
                 #faire une boucle for des vals ? 
                 self.nodes[int(node)].feat = name
-                self.nodes[int(node)].vals[int(np.round(val,4))] = int(tree_.children_left[node])
-
-                self.nodes[int(node)].feat = name
-                self.nodes[int(node)].vals[int(4854)] = int(tree_.children_right[node])
+                self.nodes[int(node)].threshold = np.round(val, 4)
+                self.nodes[int(node)].children_left = int(self.tree_.children_left[node]) 
+                self.nodes[int(node)].children_right = int(self.tree_.children_right[node])
 
-                feats.add(name)
-                fdoms[name].add(int(np.round(val,4)))
-                feats, fdoms = recurse(feats, fdoms, tree_.children_left[node])
-                fdoms[name].add(4854)
-                feats, fdoms = recurse(feats, fdoms, tree_.children_right[node])
+                recurse(self.tree_.children_left[node])
+                recurse(self.tree_.children_right[node])
 
             else:
-                self.terms[node] = class_names[np.argmax(tree_.value[node])]
-
-            return feats, fdoms
+                self.terms[node] = class_names[np.argmax(self.tree_.value[node])]
             
-        self.feats, self.fdoms = recurse(set([]), collections.defaultdict(lambda: set([])), self.root_node)
-
-        for parent in self.nodes:
-            conns = collections.defaultdict(lambda: set([]))
-            for val, child in self.nodes[parent].vals.items():
-                conns[child].add(val)
-            self.nodes[parent].vals = {frozenset(val): child for child, val in conns.items()}
+        recurse(self.root_node)
 
         self.feats = sorted(self.feats)
         self.feids = {f: i for i, f in enumerate(self.feats)}
-        self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms}
         self.nof_terms = len(self.terms)
         self.nof_nodes -= len(self.terms)
         self.nof_feats = len(self.feats)
@@ -145,70 +127,6 @@ class DecisionTree():
         self.paths = collections.defaultdict(lambda: [])
         self.extract_paths(root=self.root_node, prefix=[])
 
-    def from_dt(self, data):
-        """
-            Get the tree from a file pointer.
-        """
-
-        contents = StringIO(data)
-
-        lines = contents.readlines()
-
-        # filtering out comment lines (those that start with '#')
-        lines = list(filter(lambda l: not l.startswith('#'), lines))
-
-        # number of nodes
-        self.nof_nodes = int(lines[0].strip())
-
-        # root node
-        self.root_node = int(lines[1].strip())
-
-        # number of terminal nodes (classes)
-        self.nof_terms = len(lines[3][2:].strip().split())
-
-        # the ordered list of terminal nodes
-        self.terms = {}
-        for i in range(self.nof_terms):
-            nd, _, t = lines[i + 4].strip().split()
-            self.terms[int(nd)] = t #int(t)
-
-        # finally, reading the nodes
-        self.nodes = collections.defaultdict(lambda: Node(feat='', vals={}))
-        self.feats = set([])
-        self.fdoms = collections.defaultdict(lambda: set([]))
-        for line in lines[(4 + self.nof_terms):]:
-            # reading the tuple
-            nid, fid, fval, child = line.strip().split()
-
-            # inserting it in the nodes list
-            self.nodes[int(nid)].feat = fid
-            self.nodes[int(nid)].vals[int(fval)] = int(child)
-
-            # updating the list of features
-            self.feats.add(fid)
-
-            # updaing feature domains
-            self.fdoms[fid].add(int(fval))
-
-        # adding complex node connections into consideration
-        for n1 in self.nodes:
-            conns = collections.defaultdict(lambda: set([]))
-            for v, n2 in self.nodes[n1].vals.items():
-                conns[n2].add(v)
-            self.nodes[n1].vals = {frozenset(v): n2 for n2, v in conns.items()}
-
-        # simplifying the features and their domains
-        self.feats = sorted(self.feats)
-        self.feids = {f: i for i, f in enumerate(self.feats)}
-        self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms}
-
-        # here we assume all features are present in the tree
-        # if not, this value will be rewritten by self.parse_mapping()
-        self.nof_feats = len(self.feats)
-
-        self.paths = collections.defaultdict(lambda: [])
-        self.extract_paths(root=self.root_node, prefix=[])
-
     def extract_paths(self, root, prefix):
         """
             Traverse the tree and extract explicit paths.
@@ -220,63 +138,17 @@ class DecisionTree():
             self.paths[term].append(prefix)
         else:
             # select next node
-            feat, vals = self.nodes[root].feat, self.nodes[root].vals
-            for val in vals:
-                self.extract_paths(vals[val], prefix + [tuple([feat, val])])
-
-    def execute(self, inst, pathlits=False):
-        """
-            Run the tree and obtain the prediction given an input instance.
-        """
-
-        root = self.root_node
-        depth = 0
-        path = []
-
-        # this array is needed if we focus on the path's literals only
-        visited = [False for f in inst]
-
-        while not root in self.terms:
-            path.append(root)
-            feat, vals = self.nodes[root].feat, self.nodes[root].vals
-            visited[self.feids[feat]] = True
-            tval = inst[self.feids[feat]][1]
-            ###############
-            # assert(len(vals) == 2)
-            next_node = root
-            neq = None
-            for vs, dest in vals.items():
-                if tval in vs:
-                    next_node = dest
-                    break
-                else:
-                    for v in vs:
-                        if '!=' in self.fvmap[(feat, v)]:
-                            neq = dest
-                            break
-            else:
-                next_node = neq
-            # if tval not in vals:
-            #     # go to the False branch (!=)
-            #     for i in vals:
-            #         if "!=" in self.fvmap[(feat,i)]:
-            #             next_node = vals[i]
-            #             break
-            # else:
-            #     next_node = vals[tval]
-
-            assert (next_node != root)
-            ###############
-            root = next_node
-            depth += 1
-
-        if pathlits:
-            # filtering out non-visited literals
-            for i, v in enumerate(visited):
-                if not v:
-                    inst[i] = None
-
-        return path, self.terms[root], depth
+            feat, threshold, children_left, children_right = self.nodes[root].feat, self.nodes[root].threshold, self.nodes[root].children_left, self.nodes[root].children_right
+            self.extract_paths(children_left, prefix + [tuple([feat, "<=" + str(threshold)])])
+            self.extract_paths(children_right, prefix + [tuple([feat, ">"+ str(threshold)])])
+
+    def execute(self, inst):
+        inst = np.array([inst])
+        path = self.tree_.decision_path(inst)
+        term_id_node = self.tree_.apply(inst)
+        term_id_node = term_id_node[0]
+        path = path.indices[path.indptr[0] : path.indptr[0 + 1]]
+        return path, term_id_node
 
     def prepare_sets(self, inst, term):
         """
@@ -295,21 +167,16 @@ class DecisionTree():
                 to_hit = []
                 for item in path:
                     # if the instance disagrees with the path on this item
-                    if inst[self.feids[item[0]]] and not inst[self.feids[item[0]]][1] in item[1]:
-                        fv = inst[self.feids[item[0]]]
-                        if fv[0] in self.ohmap.opp:
-                            to_hit.append(tuple([self.ohmap.opp[fv[0]], None]))
-                        else:
-                            to_hit.append(fv)
-
-                to_hit = sorted(set(to_hit))
-                sets.append(tuple(to_hit))
-
-                if self.verbose:
-                    if self.verbose > 1:
-                        print('c trav. path: {0}'.format(path))
+                    if ("<="  in item[1] and (inst[item[0]] > np.float32(item[1][2:]))) or (">" in item[1] and (inst[item[0]] <= np.float32(item[1][1:]))) :
+                        if "<="  in item[1] :
+                            fv = tuple([item[0], str(inst[item[0]]), ">" , str(np.float32(item[1][2:]))])
+                        else :
+                            fv = tuple([item[0], str(inst[item[0]]) , "<=" , str(np.float32(item[1][1:]))])
+                        to_hit.append(fv)
 
-                    print('c set to hit: {0}'.format(to_hit))
+                if len(to_hit)>0 :
+                    to_hit = sorted(set(to_hit))
+                    sets.append(tuple(to_hit))
 
         # returning the set of sets with no duplicates
         return list(dict.fromkeys(sets))
@@ -319,25 +186,32 @@ class DecisionTree():
             Compute a given number of explanations.
         """
 
-        inst = list(map(lambda i: tuple([i[0], int(i[1])]), [i.split('=') for i in inst]))
+        inst_values = [np.float32(i[1]) for i in inst]
+        inst_dic = {}
+        for i in range(len(inst)):
+            inst_dic[inst[i][0]] = np.float32(inst[i][1])
         inst_orig = inst[:]
-        path, term, depth = self.execute(inst, pathlits)
-
-        explanation = str(inst) + "\n \n"
-        #print('c instance: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in inst_orig]), term))
-        #print(term)
-        explanation += 'c instance: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[ inst_orig[self.feids[self.nodes[n].feat]] ] for n in path]), term) + "\n"
-        explanation +='c path len:'+ str(depth)+ "\n \n \n"
+        path, term = self.execute(inst_values)
+    
+        explanation = str(inst_dic) + "\n \n"
+        decision_path_str = "c inst : IF : "
+        for node_id in path:
+            # continue to the next node if it is a leaf node
+            if term == node_id:
+                continue
 
-        if self.ohmap.dir:
-            f2v = {fv[0]: fv[1] for fv in inst}
+            decision_path_str +="(inst[{feature}] = {value}) {inequality} {threshold}) AND ".format(
+                                feature=self.nodes[node_id].feat,
+                                value=inst_dic[self.nodes[node_id].feat],
+                                inequality="<=" if inst_dic[self.nodes[node_id].feat] <= self.nodes[node_id].threshold else ">" ,
+                                threshold=self.nodes[node_id].threshold)
 
-            # updating fvmap for printing ohe features
-            for fo, fis in self.ohmap.dir.items():
-                self.fvmap[tuple([fo, None])] = '(' + ' AND '.join([self.fvmap[tuple([fi, f2v[fi]])] for fi in fis]) + ')'
+        decision_path_str += "THEN " + str(self.terms[term])
+        explanation += decision_path_str + "\n \n"
+        explanation +='c path len:'+ str(len(path))+ "\n \n \n"
 
         # computing the sets to hit
-        to_hit = self.prepare_sets(inst, term)
+        to_hit = self.prepare_sets(inst_dic, term)
 
         for type in xtype :
             if type == "AXp":
@@ -354,11 +228,14 @@ class DecisionTree():
             Enumerate abductive explanations.
         """
         explanation = ""
-        with Hitman(bootstrap_with=to_hit, solver=solver, htype=htype) as hitman:
+        with Hitman(bootstrap_with=to_hit, solver='m22', htype=htype) as hitman:
             expls = []
             for i, expl in enumerate(hitman.enumerate(), 1):
-                explanation += 'c expl: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]), term) + "\n"
-
+                explanation += 'c expl: IF {0} THEN class={1}'.format(' AND '.join(["(inst[{feature}] = {value}) {inequality} {threshold})".format(feature=p[0], 
+                                                                                                        value=p[1], 
+                                                                                                        inequality=p[2], 
+                                                                                                        threshold=p[3]) 
+                                                                                                        for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))+ "\n"
                 expls.append(expl)
                 if i == enum:
                     break
@@ -388,9 +265,10 @@ class DecisionTree():
         expls = list(reduce(process_set, to_hit, []))
         explanation = ""
         for expl in expls:
-            explanation += 'c expl: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term)+ "\n"
-
-
+            explanation += 'c expl: IF {0} THEN class!={1}'.format(' OR '.join(["inst[{feature}] {inequality} {threshold})".format(feature=p[0], 
+                                                                                                        inequality="<=" if p[2]==">" else ">", 
+                                                                                                        threshold=p[3]) 
+                                                                                                        for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))+ "\n"
         explanation +='c nof expls:'+ str(len(expls))+ "\n"
         explanation +='c min expl:'+ str( min([len(e) for e in expls]))+ "\n"
         explanation +='c max expl:'+ str( max([len(e) for e in expls]))+ "\n"
diff --git a/pages/application/DecisionTree/utils/dtviz.py b/pages/application/DecisionTree/utils/dtviz.py
index d0abf06..11ac6fb 100755
--- a/pages/application/DecisionTree/utils/dtviz.py
+++ b/pages/application/DecisionTree/utils/dtviz.py
@@ -12,7 +12,8 @@
 #==============================================================================
 from pages.application.DecisionTree.utils.dtree import DecisionTree
 import pygraphviz
-
+import numpy as np
+import pandas as pd  
 #
 #==============================================================================
 def visualize(dt):
@@ -38,18 +39,22 @@ def visualize(dt):
         node.attr['shape'] = 'square'
         node.attr['fontsize'] = 13
 
-    # transitions
     for n1 in dt.nodes:
-        for v in dt.nodes[n1].vals:
-            n2 = dt.nodes[n1].vals[v]
-            g.add_edge(n1, n2)
-            edge = g.get_edge(n1, n2)
-            if len(v) == 1:
-                edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
-            else:
-                edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
-            edge.attr['fontsize'] = 10
-            edge.attr['arrowsize'] = 0.8
+        threshold = dt.nodes[n1].threshold
+
+        children_left = dt.nodes[n1].children_left
+        g.add_edge(n1, children_left)
+        edge = g.get_edge(n1, children_left)
+        edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold)
+        edge.attr['fontsize'] = 10
+        edge.attr['arrowsize'] = 0.8
+
+        children_right = dt.nodes[n1].children_right
+        g.add_edge(n1, children_right)
+        edge = g.get_edge(n1, children_right)
+        edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold)
+        edge.attr['fontsize'] = 10
+        edge.attr['arrowsize'] = 0.8
 
     # saving file
     g.layout(prog='dot')
@@ -61,8 +66,6 @@ def visualize_instance(dt, instance):
     """
         Visualize a DT with graphviz and plot the running instance.
     """
-    instance = list(map(lambda i: tuple([i[0], int(i[1])]), [i.split('=') for i in instance]))
-
     g = pygraphviz.AGraph(directed=True, strict=True)
     g.edge_attr['dir'] = 'forward'
     g.graph_attr['rankdir'] = 'TB'
@@ -82,30 +85,34 @@ def visualize_instance(dt, instance):
         node.attr['fontsize'] = 13
 
     #path that follows the instance - colored in blue
-    path, term, depth = dt.execute(instance)
+    instance = [np.float32(i[1]) for i in instance]
+    path, term_id_node = dt.execute(instance)
     edges_instance = []
     for i in range (len(path)-1) :
         edges_instance.append((path[i], path[i+1]))
-    edges_instance.append((path[-1],"term:"+term))
-
-    # transitions
+        
     for n1 in dt.nodes:
-        for v in dt.nodes[n1].vals:
-            n2 = dt.nodes[n1].vals[v]
-            n2_type = g.get_node(n2).attr['shape']
-            g.add_edge(n1, n2)
-            edge = g.get_edge(n1, n2)
-            if len(v) == 1:
-                edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
-            else:
-                edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
-            
-            #instance path in blue
-            if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): 
-                edge.attr['color'] = 'blue'
+        threshold = dt.nodes[n1].threshold
+
+        children_left = dt.nodes[n1].children_left
+        g.add_edge(n1, children_left)
+        edge = g.get_edge(n1, children_left)
+        edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold)
+        edge.attr['fontsize'] = 10
+        edge.attr['arrowsize'] = 0.8
+        #instance path in blue
+        if ((n1,children_left) in edges_instance): 
+            edge.attr['color'] = 'blue'
 
-            edge.attr['fontsize'] = 10
-            edge.attr['arrowsize'] = 0.8
+        children_right = dt.nodes[n1].children_right
+        g.add_edge(n1, children_right)
+        edge = g.get_edge(n1, children_right)
+        edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold)
+        edge.attr['fontsize'] = 10
+        edge.attr['arrowsize'] = 0.8
+        #instance path in blue
+        if ((n1,children_right) in edges_instance): 
+            edge.attr['color'] = 'blue'
 
     # saving file
     g.layout(prog='dot')
diff --git a/utils.py b/utils.py
index 4b194ca..27f0b5e 100644
--- a/utils.py
+++ b/utils.py
@@ -10,10 +10,7 @@ def parse_contents_tree(contents, filename):
     content_type, content_string = contents.split(',')
     decoded = base64.b64decode(content_string)
     try:        
-        if '.dt' in filename:
-            data = decoded.decode('utf-8')
-            typ = 'dt'
-        elif '.pkl' in filename:
+        if '.pkl' in filename:
             data = pickle.load(io.BytesIO(decoded))
             typ = 'pkl'
     except Exception as e:
-- 
GitLab