From b7475710ca3a56e0e96557c911a80e69eb6e00a1 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Mon, 28 Mar 2022 17:45:01 +0200
Subject: [PATCH] loading with or without dataset working, need treatment of
 fdf

---
 assets/header.css                             |   2 +-
 callbacks.py                                  |  47 +++++---
 .../DecisionTree/DecisionTreeComponent.py     | 103 +++++++++++++++---
 pages/application/DecisionTree/utils/data.py  |   2 +-
 pages/application/DecisionTree/utils/dtree.py |  12 +-
 .../DecisionTree/utils/upload_tree.py         |  19 ++--
 pages/application/application.py              |  62 ++++++++---
 utils.py                                      |   8 +-
 8 files changed, 190 insertions(+), 65 deletions(-)

diff --git a/assets/header.css b/assets/header.css
index b7dcabb..8da0dd9 100644
--- a/assets/header.css
+++ b/assets/header.css
@@ -62,7 +62,7 @@ div.sidebar.col-3 {
 
 .sidebar .sidebar-dropdown{
     width: 100%;
-    height: 30px;
+    height: 40px;
     line-height: 30px;
     border-width: 1px;
     border-radius: 5px;
diff --git a/callbacks.py b/callbacks.py
index 0427370..04930d6 100644
--- a/callbacks.py
+++ b/callbacks.py
@@ -31,14 +31,15 @@ def register_callbacks(page_home, page_course, page_application, app):
 
     @app.callback(
         Output('pretrained_model_filename', 'children'),
+        Output('info_filename', 'children'),
         Output('instance_filename', 'children'),
         Output('graph', 'children'),
         Output('explanation', 'children'),
         Input('ml_model_choice', 'value'),
         Input('ml_pretrained_model_choice', 'contents'),
         State('ml_pretrained_model_choice', 'filename'),
-        Input('model_dataset_choice', 'contents'),
-        State('model_dataset_choice', 'filename'),
+        Input('model_info_choice', 'contents'),
+        State('model_info_choice', 'filename'),
         Input('ml_instance_choice', 'contents'),
         State('ml_instance_choice', 'filename'),
         Input('number_explanations', 'value'),
@@ -47,62 +48,67 @@ def register_callbacks(page_home, page_course, page_application, app):
         Input('expl_choice', 'value'),
         prevent_initial_call=True
     )
-    def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, model_dataset, model_dataset_filename, instance_contents, instance_filename, enum, xtype, solver, expl_choice):
+    def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, model_info, model_info_filename, \
+                        instance_contents, instance_filename, enum, xtype, solver, expl_choice):
         ctx = dash.callback_context
         if ctx.triggered:
             ihm_id = ctx.triggered[0]['prop_id'].split('.')[0]
             model_application = page_application.model
             if ihm_id == 'ml_model_choice' :
                     model_application.update_ml_model(value_ml_model)
-                    return None, None, None, None
+                    return None, None, None, None, None
 
             elif ihm_id == 'ml_pretrained_model_choice':
                 if model_application.ml_model is None :
                     raise PreventUpdate
                 graph = parse_contents_graph(pretrained_model_contents, pretrained_model_filename)
                 model_application.update_pretrained_model(graph)
-                return pretrained_model_filename, None, None, None
+                if not model_application.add_info :
+                    model_application.update_pretrained_model_layout()
+                    return pretrained_model_filename, None, None, model_application.component.network, None
+                else :
+                    return pretrained_model_filename, None, None, None, None
 
-            elif ihm_id == 'model_dataset_choice':
+            elif ihm_id == 'model_info_choice':
                 if model_application.ml_model is None :
                     raise PreventUpdate
-                model_dataset = parse_contents_data(model_dataset, model_dataset_filename)
-                model_application.update_pretrained_model_dataset(model_dataset)
-                return pretrained_model_filename, None, model_application.component.network, None
+                model_info = parse_contents_data(model_info, model_info_filename)
+                model_application.update_pretrained_model_layout_with_info(model_info, model_info_filename)
+                return pretrained_model_filename, model_info_filename, None, model_application.component.network, None
 
             elif ihm_id == 'ml_instance_choice' :
                 if model_application.ml_model is None or model_application.pretrained_model is None :
                     raise PreventUpdate
                 instance = parse_contents_instance(instance_contents, instance_filename)
                 model_application.update_instance(instance, enum, xtype)
-                return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation   
+                return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation   
 
             elif ihm_id == 'number_explanations' :
                 if model_application.ml_model is None or model_application.pretrained_model is None or model_application.instance is None:
                     raise PreventUpdate
                 instance = parse_contents_instance(model_application.instance, instance_filename)
                 model_application.update_instance(instance, enum, xtype)
-                return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation   
+                return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation   
 
             elif ihm_id == 'explanation_type' :
                 if model_application.ml_model is None or model_application.pretrained_model is None or model_application.instance is None:
                     raise PreventUpdate
                 instance = parse_contents_instance(model_application.instance, instance_filename)
                 model_application.update_instance(instance, enum, xtype)
-                return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation
+                return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation
             
             elif ihm_id == 'solver_sat' :
                 if model_application.ml_model is None or model_application.pretrained_model is None or model_application.instance is None:
                     raise PreventUpdate
                 instance = parse_contents_instance(model_application.instance, instance_filename)
                 model_application.update_instance(instance, enum, xtype, solver=solver)
-                return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation             
+                return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation             
             
             elif ihm_id == 'expl_choice' :
                 if instance_contents is None :
                     raise PreventUpdate
                 model_application.update_expl(expl_choice)
-                return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation             
+                return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation             
 
 
     @app.callback(
@@ -125,3 +131,16 @@ def register_callbacks(page_home, page_course, page_application, app):
             for i in range (len(model_application.list_expls)):
                 options[str(model_application.list_expls[i])] = model_application.list_expls[i]
             return False, False, False, options
+
+    @app.callback(
+        Output('model_info_choice', 'disabled'),
+        Input('add_info_model_choice', 'value'),
+        prevent_initial_call=True
+    )
+    def add_model_info(add_info_model_choice):
+        model_application = page_application.model
+        model_application.update_info_needed(add_info_model_choice)
+        if add_info_model_choice==1:
+            return False
+        else :
+            return True
diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py
index 020a956..a1e2363 100644
--- a/pages/application/DecisionTree/DecisionTreeComponent.py
+++ b/pages/application/DecisionTree/DecisionTreeComponent.py
@@ -15,34 +15,105 @@ from pages.application.DecisionTree.utils.dtviz import (visualize,
 
 class DecisionTreeComponent():
 
-    def __init__(self, tree, dataset):
+    def __init__(self, tree, info=None, type_info=''):
 
-        data = Data(dataset)
-        fvmap = data.mapping_features()
-        
-        try:
-            feature_names = tree.feature_names_in_
-        except:
-            print("You did not dump the model with the features names")
-            feature_names = [f'f{i}' for i in range(tree.n_features_in_)]
-
-        self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_)
-        self.dt_format, self.map, features_names_mapping = self.uploaded_dt.dump(fvmap, feat_names=feature_names)
-        self.dt = DecisionTree(from_dt=self.dt_format, mapfile = self.map)
 
+        if info is not None and '.csv' in type_info: 
+            self.categorical = True
+            data = Data(info)
+            fvmap = data.mapping_features()
+            feature_names = data.names[:-1]
+
+            self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_)
+            self.dt_format, self.map, features_names_mapping = self.uploaded_dt.dump(fvmap, feat_names=feature_names)
+            self.mapping_instance = self.create_fvmap_inverse_with_info(features_names_mapping)
+
+
+        elif info is not None and '.txt' in type_info :
+            self.categorical = True
+            fvmap = {}
+
+            self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_)
+            self.dt_format, self.map, features_names_mapping = self.uploaded_dt.dump(fvmap, feat_names=feature_names)
+            self.mapping_instance = self.create_fvmap_inverse_with_info(features_names_mapping)
+
+        else : 
+            self.categorical = False
+            try:
+                feature_names = tree.feature_names_in_
+            except:
+                feature_names = [f'f{i}' for i in range(tree.n_features_in_)]
+
+            self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_)
+            self.dt_format, self.map, features_names_mapping = self.uploaded_dt.convert_dt(feat_names=feature_names)
+            self.mapping_instance = self.create_fvmap_inverse_threashold(features_names_mapping)
+
+        self.dt = DecisionTree(from_dt=self.dt_format, mapfile = self.map)
         dot_source = visualize(self.dt)
         self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(dot_source=dot_source, style = {"width": "60%",
                                                                                                                 "height": "90%",
                                                                                                                 "background-color": "transparent"}))]
         self.explanation = []
 
+
+    def create_fvmap_inverse_with_info(self, features_names_mapping) :
+        mapping_instance = {}
+        for feat in features_names_mapping :
+            feat_dic = {}
+            feature_description = feat.split(',')
+            name_feat, id_feat =  feature_description[1].split(':')
+
+            for mapping in feature_description[2:]:
+                real_value, mapped_value = mapping.split(':')
+                feat_dic[np.float32(real_value)] = int(mapped_value)
+            mapping_instance[name_feat] = feat_dic
+
+        return mapping_instance
+
+
+    def create_fvmap_inverse_threashold(self, features_names_mapping) :
+        mapping_instance = {}
+        for feat in features_names_mapping :
+            feature_description = feat.split(',')
+            name_feat, id_feat =  feature_description[1].split(':')
+            mapping_instance[name_feat] = float(feature_description[2].split(':')[0])
+
+        return mapping_instance
+
+
+    def translate_instance(self, instance):
+        def translate_instance_categorical(instance):
+            instance_translated = []
+            for feat, real_value in instance :
+                instance_translated.append((feat, self.mapping_instance[feat][real_value]))
+            return instance_translated
+        
+        def translate_instance_threasholds(instance):
+            instance_translated = []
+            for feat, real_value in instance :
+                try: 
+                    if real_value <= self.mapping_instance[feat]:
+                        instance_translated.append((feat, 0))
+                    else : 
+                        instance_translated.append((feat, 1))
+                except:
+                    instance_translated.append((feat, real_value))
+            return instance_translated
+
+        if self.categorical :
+            return translate_instance_categorical(instance)
+        else : 
+            return translate_instance_threasholds(instance)
+
+
     def update_with_explicability(self, instance, enum, xtype, solver) :
        
+        instance_translated = self.translate_instance(instance)
         self.explanation = []
         list_explanations_path=[]
-        explanation = self.dt.explain(instance, enum=enum, xtype = xtype, solver=solver)
+        explanation = self.dt.explain(instance_translated, enum=enum, xtype = xtype, solver=solver)
 
-        dot_source = visualize_instance(self.dt, instance)
+        dot_source = visualize_instance(self.dt, instance_translated)
         self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(
             dot_source=dot_source, style = {"width": "50%",
                                              "height": "80%",
@@ -67,7 +138,7 @@ class DecisionTreeComponent():
         return list_explanations_path
 
     def draw_explanation(self, instance, expl) :
-        print(expl)
+        instance = self.translate_instance(instance)
         dot_source = visualize_expl(self.dt, instance, expl)
         self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(
                                 dot_source=dot_source, 
diff --git a/pages/application/DecisionTree/utils/data.py b/pages/application/DecisionTree/utils/data.py
index e719b59..d23ddb6 100644
--- a/pages/application/DecisionTree/utils/data.py
+++ b/pages/application/DecisionTree/utils/data.py
@@ -99,4 +99,4 @@ class Data(object):
                 for j, v in enumerate(sorted(self.feats[i])):
                     fvmap[f'f{i}'][j+m] = (self.names[i], False, v)                                
             
-        return fvmap          
\ No newline at end of file
+        return fvmap
diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py
index bbd6046..1e7e270 100644
--- a/pages/application/DecisionTree/utils/dtree.py
+++ b/pages/application/DecisionTree/utils/dtree.py
@@ -332,15 +332,15 @@ class DecisionTree():
         """
             Compute a given number of explanations.
         """
-        
-        self.feids = {f[0]: i for i, f in enumerate(inst)}
-
-        path, term, depth = self.execute(inst, pathlits)
-    
         #contaiins all the elements for explanation
         explanation_dic = {}
         #instance plotting
-        explanation_dic["Instance : "] = str([self.fvmap[inst[i]] for i in range (len(inst))])
+        explanation_dic["Instance : "] = str([str(inst[i]) for i in range (len(inst))])
+
+        self.feids = {f'f{i}': i for i, f in enumerate(inst)}
+        inst = [(f'f{i}', int(inst[i][1])) for i,f in enumerate(inst)]
+
+        path, term, depth = self.execute(inst, pathlits)
 
         #decision path
         decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[inst[self.feids[self.nodes[n].feat]]] for n in path]), term)
diff --git a/pages/application/DecisionTree/utils/upload_tree.py b/pages/application/DecisionTree/utils/upload_tree.py
index ea82f0a..b3e1297 100644
--- a/pages/application/DecisionTree/utils/upload_tree.py
+++ b/pages/application/DecisionTree/utils/upload_tree.py
@@ -239,10 +239,11 @@ class UploadedDecisionTree:
     
             
         if feat_names is not None:
-            features_names_mapping = ''
+            features_names_mapping = []
             for i,fid in enumerate(feat_names):
-                f=f'f{i}'
-                features_names_mapping += f'T:C,{fid}:{f},'+",".join([f'{fvmap[f][v][2]}:{v}' for v in fvmap[f] if(fvmap[f][v][1])])+'\n'
+                feat=f'f{i}'
+                f = f'T:C,{fid}:{feat},'+",".join([f'{fvmap[feat][v][2]}:{v}' for v in fvmap[feat] if(fvmap[feat][v][1])])
+                features_names_mapping.append(f)
 
         return dt, map, features_names_mapping    
 
@@ -299,13 +300,15 @@ class UploadedDecisionTree:
                 map += f"\n{f} {j} <={t}"
             map += f"\n{f} {j+1} >{t}"  
 
+
         if feat_names is not None:
-            features_names_mapping = ''
+            features_names_mapping = []
             for i,fid in enumerate(feat_names):
-                f=f'f{i}'
-                if f in self.intvs:
-                    features_names_mapping += f'\n Categorical,{fid}:{f},'
-                    features_names_mapping += ",".join([f'{t}:{j}' for j,t in enumerate(self.intvs[f])])
+                feat=f'f{i}'
+                if feat in self.intvs:
+                    f = f'T:O,{fid}:{feat},'
+                    f += ",".join([f'{t}:{j}' for j,t in enumerate(self.intvs[feat])])                    
+                    features_names_mapping.append(f)
 
         return dt, map, features_names_mapping    
     
diff --git a/pages/application/application.py b/pages/application/application.py
index f5c7032..925a1b7 100644
--- a/pages/application/application.py
+++ b/pages/application/application.py
@@ -19,7 +19,9 @@ class Model():
         self.ml_model = ''
 
         self.pretrained_model = ''
-        self.model_dataset = ''
+
+        self.add_info = False
+        self.model_info = ''
 
         self.instance = ''
 
@@ -37,9 +39,15 @@ class Model():
     def update_pretrained_model(self, pretrained_model_update):
         self.pretrained_model = pretrained_model_update
 
-    def update_pretrained_model_dataset(self, model_dataset):
-        self.model_dataset = model_dataset
-        self.component = self.component_class(self.pretrained_model, self.model_dataset)
+    def update_info_needed(self, add_info):
+        self.add_info = add_info
+
+    def update_pretrained_model_layout(self):
+        self.component = self.component_class(self.pretrained_model)
+
+    def update_pretrained_model_layout_with_info(self, model_info, model_info_filename):
+        self.model_info = model_info
+        self.component = self.component_class(self.pretrained_model, self.model_info, model_info_filename)
 
     def update_instance(self, instance, enum, xtype, solver="g3"):
         self.instance = instance
@@ -57,7 +65,13 @@ class View():
         self.ml_menu_models = dcc.Dropdown(self.model.ml_models, 
                                             id='ml_model_choice',
                                             className="sidebar-dropdown")
-        
+
+        self.ml_library_used = dcc.Dropdown(options = [{'label': 'Scikit-learn ', 'value': "SKL"},
+                                                        {'label': 'ITI', 'value': "ITI"},
+                                                        {'label': 'IAI', 'value': "IAI"}], 
+                                            id='ml_library_choice',
+                                            className="sidebar-dropdown")
+
         self.pretrained_model_upload = html.Div([
                                     dcc.Upload(        
                                         id='ml_pretrained_model_choice',
@@ -69,15 +83,22 @@ class View():
                                     ),
                                     html.Div(id='pretrained_model_filename')])
 
-        self.model_dataset = html.Div([
+        self.add_model_info_choice = dcc.RadioItems(id="add_info_model_choice", 
+                                                    options = [{'label': 'Yes ', 'value': 1},
+                                                               {'label': 'No', 'value': 0}], 
+                                                    value=0, className="sidebar-dropdown")
+
+        self.model_info = html.Div([
                                     dcc.Upload(        
-                                        id='model_dataset_choice',
+                                        id='model_info_choice',
+                                        disabled=True,
                                         children=html.Div([
                                             'Drag and Drop or ',
                                             html.A('Select File')
                                         ]),
                                         className="upload"
-                                    )])
+                                    ),
+                                    html.Div(id='info_filename')])
 
         self.instance_upload = html.Div([
                                     dcc.Upload(        
@@ -97,17 +118,27 @@ class View():
                                     html.Br(),
                                     self.ml_menu_models,
                                     html.Hr(),
+                                    html.Label("Choose the Machine Learning library used :"),
+                                    html.Br(),
+                                    self.ml_library_used,
+                                    html.Hr(),
                                     html.Label("Choose the pretrained model : "),
                                     html.Br(),
                                     self.pretrained_model_upload, 
                                     html.Hr(),
-                                    html.Label("Choose the pretrained model dataset : "),
-                                    self.model_dataset,
+                                    html.Label("Do you wish to upload more info for your model ? : "),
+                                    html.Br(),
+                                    self.add_model_info_choice,
                                     html.Hr(),
-                                    html.Label("Choose the instance to explain : "),
+                                    html.Label("Choose the pretrained model dataset (csv) or feature definition file (txt): "),
                                     html.Br(),
-                                    self.instance_upload,
+                                    self.model_info,
                                     html.Hr(),
+                                    html.Label("Choose the instance to explain : "),
+                                    html.Br(),
+                                    self.instance_upload], className="sidebar"),
+                                dcc.Tab(label='Advanced Parameters', children = [ 
+                                    html.Br(),
                                     html.Label("Choose the number of explanations : "),
                                     html.Br(),
                                     dcc.Input(
@@ -124,17 +155,14 @@ class View():
                                         options={'AXp' : "Abductive Explanation", 'CXp': "Contrastive explanation"},
                                         value = ['AXp', 'CXp'],
                                         className="sidebar-dropdown",
-                                        inline=True)], className="sidebar"),
-                                dcc.Tab(label='Advanced Parameters', children = [  
+                                        inline=True),
                                     html.Hr(),
                                     html.Label("Choose the SAT solver : "),
                                     html.Br(),
                                     dcc.Dropdown(['g3', 'g4', 'lgl', 'mcb', 'mcm', 'mpl', 'm22', 'mc', 'mgh'], 'g3', id='solver_sat') 
-                                ], className="sidebar")
-                        ])
+                                ], className="sidebar")])
                             
  
-
         self.expl_choice = dcc.Dropdown(self.model.list_expls,
                                         id='expl_choice',  
                                         className="dropdown")
diff --git a/utils.py b/utils.py
index 4b38288..b342910 100644
--- a/utils.py
+++ b/utils.py
@@ -27,6 +27,8 @@ def parse_contents_data(contents, filename):
     try:        
         if '.csv' in filename:
             data = decoded.decode('utf-8')
+        if '.txt' in filename:
+            data = decoded.decode('utf-8')
     except Exception as e:
         print(e)
         return html.Div([
@@ -39,9 +41,11 @@ def parse_contents_instance(contents, filename):
     content_type, content_string = contents.split(',')
     decoded = base64.b64decode(content_string)
     try:
-        if 'csv' in filename:
+        if '.csv' in filename:
             data = decoded.decode('utf-8')
-        elif 'txt' in filename:
+        elif '.txt' in filename:
+            data = decoded.decode('utf-8')       
+        elif '.json' in filename:
             data = decoded.decode('utf-8')
         else : 
             data = decoded.decode('utf-8')
-- 
GitLab