From 022452bcff7fd606a2be6fa0254392dc003070ab Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Fri, 25 Mar 2022 15:20:03 +0100
Subject: [PATCH] added dataset upload and now works with pkl file, however
 wrong form on inftance

---
 callbacks.py                                  |  17 ++-
 .../DecisionTree/DecisionTreeComponent.py     |  22 ++-
 pages/application/DecisionTree/utils/data.py  | 102 +++++++++++++
 pages/application/DecisionTree/utils/dtree.py |  28 ++--
 pages/application/DecisionTree/utils/dtviz.py |  11 +-
 .../DecisionTree/utils/upload_tree.py         | 141 ++++++++++++++++--
 pages/application/application.py              |  21 ++-
 utils.py                                      |  13 ++
 8 files changed, 308 insertions(+), 47 deletions(-)
 create mode 100644 pages/application/DecisionTree/utils/data.py

diff --git a/callbacks.py b/callbacks.py
index 95101df..0427370 100644
--- a/callbacks.py
+++ b/callbacks.py
@@ -4,7 +4,7 @@ from dash import Input, Output, State
 from dash.dependencies import Input, Output, State
 from dash.exceptions import PreventUpdate
 
-from utils import parse_contents_graph, parse_contents_instance
+from utils import parse_contents_graph, parse_contents_instance, parse_contents_data
 
 
 def register_callbacks(page_home, page_course, page_application, app):
@@ -37,6 +37,8 @@ def register_callbacks(page_home, page_course, page_application, app):
         Input('ml_model_choice', 'value'),
         Input('ml_pretrained_model_choice', 'contents'),
         State('ml_pretrained_model_choice', 'filename'),
+        Input('model_dataset_choice', 'contents'),
+        State('model_dataset_choice', 'filename'),
         Input('ml_instance_choice', 'contents'),
         State('ml_instance_choice', 'filename'),
         Input('number_explanations', 'value'),
@@ -45,7 +47,7 @@ def register_callbacks(page_home, page_course, page_application, app):
         Input('expl_choice', 'value'),
         prevent_initial_call=True
     )
-    def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, instance_contents, instance_filename, enum, xtype, solver, expl_choice):
+    def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, model_dataset, model_dataset_filename, instance_contents, instance_filename, enum, xtype, solver, expl_choice):
         ctx = dash.callback_context
         if ctx.triggered:
             ihm_id = ctx.triggered[0]['prop_id'].split('.')[0]
@@ -57,8 +59,15 @@ def register_callbacks(page_home, page_course, page_application, app):
             elif ihm_id == 'ml_pretrained_model_choice':
                 if model_application.ml_model is None :
                     raise PreventUpdate
-                tree = parse_contents_graph(pretrained_model_contents, pretrained_model_filename)
-                model_application.update_pretrained_model(tree, pretrained_model_filename)
+                graph = parse_contents_graph(pretrained_model_contents, pretrained_model_filename)
+                model_application.update_pretrained_model(graph)
+                return pretrained_model_filename, None, None, None
+
+            elif ihm_id == 'model_dataset_choice':
+                if model_application.ml_model is None :
+                    raise PreventUpdate
+                model_dataset = parse_contents_data(model_dataset, model_dataset_filename)
+                model_application.update_pretrained_model_dataset(model_dataset)
                 return pretrained_model_filename, None, model_application.component.network, None
 
             elif ihm_id == 'ml_instance_choice' :
diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py
index 8c795d7..020a956 100644
--- a/pages/application/DecisionTree/DecisionTreeComponent.py
+++ b/pages/application/DecisionTree/DecisionTreeComponent.py
@@ -6,6 +6,7 @@ import dash_interactive_graphviz
 import numpy as np
 from dash import dcc, html
 from pages.application.DecisionTree.utils.upload_tree import UploadedDecisionTree
+from pages.application.DecisionTree.utils.data import Data
 from pages.application.DecisionTree.utils.dtree import DecisionTree
 
 from pages.application.DecisionTree.utils.dtviz import (visualize,
@@ -14,18 +15,20 @@ from pages.application.DecisionTree.utils.dtviz import (visualize,
 
 class DecisionTreeComponent():
 
-    def __init__(self, tree, filename_tree):
+    def __init__(self, tree, dataset):
 
+        data = Data(dataset)
+        fvmap = data.mapping_features()
+        
         try:
             feature_names = tree.feature_names_in_
         except:
             print("You did not dump the model with the features names")
             feature_names = [f'f{i}' for i in range(tree.n_features_in_)]
-        self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', filename_tree, maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_)
 
-        #need a function that takes as input UploadedDecisionTree and gives DecisionTree
-        self.dt_format, self.map, features_names_mapping = self.uploaded_dt.convert_dt(feat_names=feature_names)
-        self.dt = DecisionTree(from_dt=self.dt_format, mapfile = self.map, mapping_features=features_names_mapping)
+        self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_)
+        self.dt_format, self.map, features_names_mapping = self.uploaded_dt.dump(fvmap, feat_names=feature_names)
+        self.dt = DecisionTree(from_dt=self.dt_format, mapfile = self.map)
 
         dot_source = visualize(self.dt)
         self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(dot_source=dot_source, style = {"width": "60%",
@@ -34,6 +37,11 @@ class DecisionTreeComponent():
         self.explanation = []
 
     def update_with_explicability(self, instance, enum, xtype, solver) :
+       
+        self.explanation = []
+        list_explanations_path=[]
+        explanation = self.dt.explain(instance, enum=enum, xtype = xtype, solver=solver)
+
         dot_source = visualize_instance(self.dt, instance)
         self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(
             dot_source=dot_source, style = {"width": "50%",
@@ -41,9 +49,6 @@ class DecisionTreeComponent():
                                             "background-color": "transparent"}
         ))]
 
-        self.explanation = []
-        list_explanations_path=[]
-        explanation = self.dt.explain(instance, enum=enum, xtype = xtype, solver=solver)
 
         #Creating a clean and nice text component
         for k in explanation.keys() :
@@ -62,6 +67,7 @@ class DecisionTreeComponent():
         return list_explanations_path
 
     def draw_explanation(self, instance, expl) :
+        print(expl)
         dot_source = visualize_expl(self.dt, instance, expl)
         self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(
                                 dot_source=dot_source, 
diff --git a/pages/application/DecisionTree/utils/data.py b/pages/application/DecisionTree/utils/data.py
new file mode 100644
index 0000000..e719b59
--- /dev/null
+++ b/pages/application/DecisionTree/utils/data.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python
+#-*- coding:utf-8 -*-
+##
+## data.py
+##
+##  Created on: Sep 20, 2017
+##      Author: Alexey Ignatiev, Nina Narodytska
+##      E-mail: aignatiev@ciencias.ulisboa.pt, narodytska@vmware.com
+##
+
+#
+#==============================================================================
+from __future__ import print_function
+import collections
+import itertools
+import os, pickle
+import six
+import gzip
+from six.moves import range
+import numpy as np
+import pandas as pd
+
+#from  sklearn.preprocessing import OneHotEncoder
+from sklearn.model_selection import train_test_split
+
+
+#
+#==============================================================================
+class Data(object):
+    """
+        Class for representing data (transactions).
+    """
+
+    def __init__(self, data, separator=','):
+        """
+            Constructor and parser.
+        """
+        self.names = None
+        self.nm2id = None
+        self.feats = None
+        self.targets = None
+        self.samples = None
+
+        self.parse(data, separator)
+           
+    def parse(self, data, separator):
+        """
+            Parse input file.
+        """
+
+        # reading data set from file
+        lines = data.split('\n')
+
+        # reading preamble
+        self.names = [name.replace('"','').strip() for name in lines[0].strip().split(separator)]
+        self.feats = [set([]) for n in self.names[:-1]]
+        self.targets = set([])
+        
+        lines = lines[1:]
+
+        # filling name to id mapping
+        self.nm2id = {name: i for i, name in enumerate(self.names)}
+
+        self.nonbin2bin = {}
+        for name in self.nm2id:
+            spl = name.rsplit(':',1)
+            if (spl[0] not in self.nonbin2bin):
+                self.nonbin2bin[spl[0]] = [name]
+            else:
+                self.nonbin2bin[spl[0]].append(name)
+
+        # reading training samples
+        self.samples =  []
+
+        for line, w in six.iteritems(collections.Counter(lines)):
+            inst = [v.strip() for v in line.strip().split(separator)]
+            self.samples.append(inst)
+            for i, v in enumerate(inst[:-1]):
+                if v:
+                    self.feats[i].add(str(v)) 
+            assert(inst[-1])
+            self.targets.add(str(inst[-1]))
+        
+        self.nof_feats = len(self.names[:-1])            
+        
+    def mapping_features(self):
+        """
+            feature-value mapping
+        """
+        fvmap = {}
+        
+        for i in range(self.nof_feats):
+            fvmap[f'f{i}'] = dict() 
+            for j, v in enumerate(sorted(self.feats[i])):
+                fvmap[f'f{i}'][j] = (self.names[i], True, v)
+            
+            if len(self.feats[i]) > 2:
+                m = len(self.feats[i])
+                for j, v in enumerate(sorted(self.feats[i])):
+                    fvmap[f'f{i}'][j+m] = (self.names[i], False, v)                                
+            
+        return fvmap          
\ No newline at end of file
diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py
index 837decf..bbd6046 100644
--- a/pages/application/DecisionTree/utils/dtree.py
+++ b/pages/application/DecisionTree/utils/dtree.py
@@ -45,7 +45,7 @@ class DecisionTree():
         Simple decision tree class.
     """
 
-    def __init__(self, from_dt=None, mapfile=None, mapping_features=None, verbose=0):
+    def __init__(self, from_dt=None, mapfile=None, verbose=0):
         """
             Constructor.
         """
@@ -63,8 +63,6 @@ class DecisionTree():
         self.fdoms = {}
         self.fvmap = {}
 
-        self.features_names = mapping_features
-
         # OHE mapping
         OHEMap = collections.namedtuple('OHEMap', ['dir', 'opp'])
         self.ohmap = OHEMap(dir={}, opp={})
@@ -79,7 +77,6 @@ class DecisionTree():
                 for v in self.fdoms[f]:
                     self.fvmap[tuple([f, v])] = '{0}={1}'.format(f, v)
 
-
     def from_dt(self, data):
         """
             Get the tree from a file pointer.
@@ -134,7 +131,7 @@ class DecisionTree():
 
         # simplifying the features and their domains
         self.feats = sorted(self.feats)
-        self.feids = {f: i for i, f in enumerate(self.feats)}
+        #self.feids = {f: i for i, f in enumerate(self.feats)}
         self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms}
 
         # here we assume all features are present in the tree
@@ -171,9 +168,13 @@ class DecisionTree():
             # skipping the first comment line if necessary
             lines = lines[1:]
 
+        # number of features
+        self.nof_feats = int(lines[0].strip())
+        self.feids = {}
+
         for line in lines[1:]:
             feat, val, real = line.split()
-            self.fvmap[tuple([feat, int(val)])] = '{0}{1}'.format(self.features_names[feat], real)
+            self.fvmap[tuple([feat, int(val)])] = '{0}{1}'.format(feat, real)
             #if feat not in self.feids:
             #    self.feids[feat] = len(self.feids)
 
@@ -331,19 +332,18 @@ class DecisionTree():
         """
             Compute a given number of explanations.
         """
+        
+        self.feids = {f[0]: i for i, f in enumerate(inst)}
 
-        inst_dic = {}
-        for i in range(len(inst)):
-            inst_dic[inst[i][0]] = np.float32(inst[i][1])
         path, term, depth = self.execute(inst, pathlits)
     
         #contaiins all the elements for explanation
         explanation_dic = {}
         #instance plotting
-        explanation_dic["Instance : "] = str(inst_dic)
+        explanation_dic["Instance : "] = str([self.fvmap[inst[i]] for i in range (len(inst))])
 
         #decision path
-        decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([str(inst[self.feids[self.nodes[n].feat]]) for n in path]), term)
+        decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[inst[self.feids[self.nodes[n].feat]]] for n in path]), term)
         explanation_dic["Decision path of instance : "] = decision_path_str
         explanation_dic["Decision path length : "] = 'Path length is :'+ str(depth)
 
@@ -375,8 +375,8 @@ class DecisionTree():
         with Hitman(bootstrap_with=to_hit, solver='m22', htype=htype) as hitman:
             expls = []
             for i, expl in enumerate(hitman.enumerate(), 1):
-                list_expls.append([ str(p[0]) + "=" + str(p[1]) for p in expl])
-                list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join([str(p) for p in sorted(expl, key=lambda p: p[0])]), term))
+                list_expls.append([self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])])
+                list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]), term))
 
                 expls.append(expl)
                 if i == enum:
@@ -409,7 +409,7 @@ class DecisionTree():
         list_expls_str = []
         explanation = {}
         for expl in expls:
-            list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(str(p)) for p in sorted(expl, key=lambda p: p[0])]), term))
+            list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term))
 
         explanation["List of contrastive explanation(s)"] = list_expls_str
         explanation["Number of contrastive explanation(s) : "]=str(len(expls))
diff --git a/pages/application/DecisionTree/utils/dtviz.py b/pages/application/DecisionTree/utils/dtviz.py
index 681ccca..0c940c6 100755
--- a/pages/application/DecisionTree/utils/dtviz.py
+++ b/pages/application/DecisionTree/utils/dtviz.py
@@ -50,7 +50,7 @@ def visualize(dt):
 
     # non-terminal nodes
     for n in dt.nodes:
-        g.add_node(n, label=dt.features_names[dt.nodes[n].feat])
+        g.add_node(n, label=dt.nodes[n].feat)
         node = g.get_node(n)
         node.attr['shape'] = 'circle'
         node.attr['fontsize'] = 13
@@ -98,7 +98,7 @@ def visualize_instance(dt, instance):
 
     # non-terminal nodes
     for n in dt.nodes:
-        g.add_node(n, label=dt.features_names[dt.nodes[n].feat])
+        g.add_node(n, label=dt.nodes[n].feat)
         node = g.get_node(n)
         node.attr['shape'] = 'circle'
         node.attr['fontsize'] = 13
@@ -156,7 +156,7 @@ def visualize_expl(dt, instance, expl):
 
     # non-terminal nodes
     for n in dt.nodes:
-        g.add_node(n, label=dt.features_names[dt.nodes[n].feat])
+        g.add_node(n, label=dt.nodes[n].feat)
         node = g.get_node(n)
         node.attr['shape'] = 'circle'
         node.attr['fontsize'] = 13
@@ -183,8 +183,9 @@ def visualize_expl(dt, instance, expl):
             #instance path in dashed
             if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): 
                 edge.attr['style'] = 'dashed'
-            if edge.attr['label'] in expl:
-                edge.attr['color'] = 'blue'
+            for label in edge.attr['label'].split('\n'):
+                if label in expl:
+                    edge.attr['color'] = 'blue'
 
             edge.attr['fontsize'] = 10
             edge.attr['arrowsize'] = 0.8
diff --git a/pages/application/DecisionTree/utils/upload_tree.py b/pages/application/DecisionTree/utils/upload_tree.py
index e2babad..ea82f0a 100644
--- a/pages/application/DecisionTree/utils/upload_tree.py
+++ b/pages/application/DecisionTree/utils/upload_tree.py
@@ -82,7 +82,7 @@ def scores_tree(node, sample):
 
 #
 #==============================================================================
-def get_json_tree(model, tool, maxdepth=None, fname=None):
+def get_json_tree(model, tool, maxdepth=None):
     """
         returns the dtree in JSON format 
     """
@@ -116,19 +116,136 @@ class UploadedDecisionTree:
     """ A decision tree.
     This object provides a common interface to many different types of models.
     """
-    def __init__(self, model, tool, fname, maxdepth, feature_names=None, nb_classes = 0):
+    def __init__(self, model, tool, maxdepth, feature_names=None, nb_classes = 0):
         self.tool  = tool
         self.model = model
         self.tree  = None
         self.depth = None
         self.n_nodes = None
-        json_tree = get_json_tree(self.model, self.tool, maxdepth, fname)
+        json_tree = get_json_tree(self.model, self.tool, maxdepth)
         self.tree, self.n_nodes, self.depth = self.build_tree(json_tree, feature_names)
              
     def print_tree(self):
         print("DT model:")
         walk_tree(self.tree)
+
+
+    def dump(self, fvmap, filename=None, maxdepth=None, feat_names=None): 
+        """
+            save the dtree and data map in .dt/.map file   
+        """
+                     
+        def walk_tree(node, domains, internal, terminal):
+            """
+                extract internal (non-term) & terminal nodes
+            """
+            if (len(node.children) == 0): # leaf node
+                terminal.append((node.id, node.values))
+            else:
+                assert (node.children[0].id == node.left_node_id)
+                assert (node.children[1].id == node.right_node_id)
+                
+                f = f"f{node.feature}"
+                
+                if self.tool == "DL85":
+                    l,r = (1,0)
+                    internal.append((node.id, f, l, node.children[0].id))
+                    internal.append((node.id, f, r, node.children[1].id))
+                    
+                elif self.tool == "ITI":
+                    #l,r = (0,1)                    
+                    if len(fvmap[f]) > 2:
+                        n = 0
+                        for v in fvmap[f]:
+                            if (fvmap[f][v][2] == node.threshold) and \
+                                            (fvmap[f][v][1] == True):
+                                l = v
+                                n = n + 1
+                            if (fvmap[f][v][2] == node.threshold) and \
+                                            (fvmap[f][v][1] == False):  
+                                r = v
+                                n = n + 1
+                                
+                        assert (n == 2)      
+                            
+                    elif (fvmap[f][0][2] == node.threshold):
+                        l,r = (0,1)
+                    else:
+                        assert (fvmap[f][1][2] == node.threshold)
+                        l,r = (1,0)
+      
+                    internal.append((node.id, f, l, node.children[0].id))
+                    internal.append((node.id, f, r, node.children[1].id))                            
+                        
+                elif self.tool == "IAI":
+                    left, right = [], []
+                    for p in fvmap[f]:
+                        if fvmap[f][p][1] == True:
+                            assert (fvmap[f][p][2] in node.split)
+                            if node.split[fvmap[f][p][2]]: 
+                                left.append(p)
+                            else:
+                                right.append(p)
+                    
+                    internal.extend([(node.id, f, l, node.children[0].id) for l in left]) 
+                    internal.extend([(node.id, f, r, node.children[1].id) for r in right])    
+                
+                elif self.tool == 'SKL':
+                    left, right = [], []
+                    for j in domains[f]:
+                        if np.float32(fvmap[f][j][2]) <= np.float32(node.threshold):
+                            left.append(j)
+                        else:
+                            right.append(j)
+
+                    internal.extend([(node.id, f, l, node.children[0].id) for l in left]) 
+                    internal.extend([(node.id, f, r, node.children[1].id) for r in right]) 
+                    
+                    dom0, dom1 = dict(), dict()
+                    dom0.update(domains)
+                    dom1.update(domains)
+                    dom0[f] = left
+                    dom1[f] = right                     
+                  
+                else:
+                    assert False, 'Unhandled model type: {0}'.format(self.tool)
+
+                
+                internal, terminal = walk_tree(node.children[0], dom0, internal, terminal)
+                internal, terminal = walk_tree(node.children[1], dom1, internal, terminal)
+                
+            return internal, terminal 
+        
+        domains = {f:[j for j in fvmap[f] if((fvmap[f][j][1]))] for f in fvmap}
+        internal, terminal = walk_tree(self.tree, domains, [], [])
         
+
+        dt = f"{self.n_nodes}\n{self.tree.id}\n"
+        dt += f"I {' '.join(dict.fromkeys([str(i) for i,_,_,_ in internal]))}\n"
+        dt +=f"T {' '.join([str(i) for i,_ in terminal ])}\n"
+        for i,c in terminal:
+            dt +=f"{i} T {c}\n"            
+        for i,f, j, n in internal: 
+            dt +=f"{i} {f} {j} {n}\n"
+
+        map = "Categorical\n"
+        map += f"{len(fvmap)}"
+        for f in fvmap:
+            for v in fvmap[f]:
+                if (fvmap[f][v][1] == True):
+                    map += f"\n{f} {v} ={fvmap[f][v][2]}"
+                if (fvmap[f][v][1] == False) and self.tool == "ITI":
+                    map += f"\n{f} {v} !={fvmap[f][v][2]}"
+    
+            
+        if feat_names is not None:
+            features_names_mapping = ''
+            for i,fid in enumerate(feat_names):
+                f=f'f{i}'
+                features_names_mapping += f'T:C,{fid}:{f},'+",".join([f'{fvmap[f][v][2]}:{v}' for v in fvmap[f] if(fvmap[f][v][1])])+'\n'
+
+        return dt, map, features_names_mapping    
+
     def convert_dt(self, feat_names):
         """
             save dtree in .dt format & generate dtree map from the tree
@@ -182,18 +299,14 @@ class UploadedDecisionTree:
                 map += f"\n{f} {j} <={t}"
             map += f"\n{f} {j+1} >{t}"  
 
-        features_names_mapping = {}
-        for i,fid in enumerate(feat_names):
-            f=f'f{i}'
-            if f in self.intvs:
-                features_names_mapping[f] = fid
-                #features_names_mapping += ",".join([f'{t}:{j}' for j,t in enumerate(self.intvs[f])])+'\n'
-                #thresholds = self.intvs[f][:-1]+[self.intvs[f][-2]]
-                #fp.write(",".join([f'{t}:{j}' for j,t in enumerate(thresholds)])+'\n')
+        if feat_names is not None:
+            features_names_mapping = ''
+            for i,fid in enumerate(feat_names):
+                f=f'f{i}'
+                if f in self.intvs:
+                    features_names_mapping += f'\n Categorical,{fid}:{f},'
+                    features_names_mapping += ",".join([f'{t}:{j}' for j,t in enumerate(self.intvs[f])])
 
-        print(dt)
-        print(map)
-        print(features_names_mapping)
         return dt, map, features_names_mapping    
     
     
diff --git a/pages/application/application.py b/pages/application/application.py
index a27efac..f5c7032 100644
--- a/pages/application/application.py
+++ b/pages/application/application.py
@@ -19,6 +19,7 @@ class Model():
         self.ml_model = ''
 
         self.pretrained_model = ''
+        self.model_dataset = ''
 
         self.instance = ''
 
@@ -33,9 +34,12 @@ class Model():
         self.component_class = self.dict_components[self.ml_model]
         self.component_class =  globals()[self.component_class]
 
-    def update_pretrained_model(self, pretrained_model_update, filename_model):
+    def update_pretrained_model(self, pretrained_model_update):
         self.pretrained_model = pretrained_model_update
-        self.component = self.component_class(self.pretrained_model, filename_model)
+
+    def update_pretrained_model_dataset(self, model_dataset):
+        self.model_dataset = model_dataset
+        self.component = self.component_class(self.pretrained_model, self.model_dataset)
 
     def update_instance(self, instance, enum, xtype, solver="g3"):
         self.instance = instance
@@ -65,6 +69,16 @@ class View():
                                     ),
                                     html.Div(id='pretrained_model_filename')])
 
+        self.model_dataset = html.Div([
+                                    dcc.Upload(        
+                                        id='model_dataset_choice',
+                                        children=html.Div([
+                                            'Drag and Drop or ',
+                                            html.A('Select File')
+                                        ]),
+                                        className="upload"
+                                    )])
+
         self.instance_upload = html.Div([
                                     dcc.Upload(        
                                         id='ml_instance_choice',
@@ -87,6 +101,9 @@ class View():
                                     html.Br(),
                                     self.pretrained_model_upload, 
                                     html.Hr(),
+                                    html.Label("Choose the pretrained model dataset : "),
+                                    self.model_dataset,
+                                    html.Hr(),
                                     html.Label("Choose the instance to explain : "),
                                     html.Br(),
                                     self.instance_upload,
diff --git a/utils.py b/utils.py
index 4562909..4b38288 100644
--- a/utils.py
+++ b/utils.py
@@ -21,6 +21,19 @@ def parse_contents_graph(contents, filename):
 
     return data
 
+def parse_contents_data(contents, filename):
+    content_type, content_string = contents.split(',')
+    decoded = base64.b64decode(content_string)
+    try:        
+        if '.csv' in filename:
+            data = decoded.decode('utf-8')
+    except Exception as e:
+        print(e)
+        return html.Div([
+            'There was an error processing this file.'
+        ])
+
+    return data
 
 def parse_contents_instance(contents, filename):
     content_type, content_string = contents.split(',')
-- 
GitLab