From b7475710ca3a56e0e96557c911a80e69eb6e00a1 Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <cdepourt@montana.irit.fr> Date: Mon, 28 Mar 2022 17:45:01 +0200 Subject: [PATCH] loading with or without dataset working, need treatment of fdf --- assets/header.css | 2 +- callbacks.py | 47 +++++--- .../DecisionTree/DecisionTreeComponent.py | 103 +++++++++++++++--- pages/application/DecisionTree/utils/data.py | 2 +- pages/application/DecisionTree/utils/dtree.py | 12 +- .../DecisionTree/utils/upload_tree.py | 19 ++-- pages/application/application.py | 62 ++++++++--- utils.py | 8 +- 8 files changed, 190 insertions(+), 65 deletions(-) diff --git a/assets/header.css b/assets/header.css index b7dcabb..8da0dd9 100644 --- a/assets/header.css +++ b/assets/header.css @@ -62,7 +62,7 @@ div.sidebar.col-3 { .sidebar .sidebar-dropdown{ width: 100%; - height: 30px; + height: 40px; line-height: 30px; border-width: 1px; border-radius: 5px; diff --git a/callbacks.py b/callbacks.py index 0427370..04930d6 100644 --- a/callbacks.py +++ b/callbacks.py @@ -31,14 +31,15 @@ def register_callbacks(page_home, page_course, page_application, app): @app.callback( Output('pretrained_model_filename', 'children'), + Output('info_filename', 'children'), Output('instance_filename', 'children'), Output('graph', 'children'), Output('explanation', 'children'), Input('ml_model_choice', 'value'), Input('ml_pretrained_model_choice', 'contents'), State('ml_pretrained_model_choice', 'filename'), - Input('model_dataset_choice', 'contents'), - State('model_dataset_choice', 'filename'), + Input('model_info_choice', 'contents'), + State('model_info_choice', 'filename'), Input('ml_instance_choice', 'contents'), State('ml_instance_choice', 'filename'), Input('number_explanations', 'value'), @@ -47,62 +48,67 @@ def register_callbacks(page_home, page_course, page_application, app): Input('expl_choice', 'value'), prevent_initial_call=True ) - def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, model_dataset, model_dataset_filename, instance_contents, instance_filename, enum, xtype, solver, expl_choice): + def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, model_info, model_info_filename, \ + instance_contents, instance_filename, enum, xtype, solver, expl_choice): ctx = dash.callback_context if ctx.triggered: ihm_id = ctx.triggered[0]['prop_id'].split('.')[0] model_application = page_application.model if ihm_id == 'ml_model_choice' : model_application.update_ml_model(value_ml_model) - return None, None, None, None + return None, None, None, None, None elif ihm_id == 'ml_pretrained_model_choice': if model_application.ml_model is None : raise PreventUpdate graph = parse_contents_graph(pretrained_model_contents, pretrained_model_filename) model_application.update_pretrained_model(graph) - return pretrained_model_filename, None, None, None + if not model_application.add_info : + model_application.update_pretrained_model_layout() + return pretrained_model_filename, None, None, model_application.component.network, None + else : + return pretrained_model_filename, None, None, None, None - elif ihm_id == 'model_dataset_choice': + elif ihm_id == 'model_info_choice': if model_application.ml_model is None : raise PreventUpdate - model_dataset = parse_contents_data(model_dataset, model_dataset_filename) - model_application.update_pretrained_model_dataset(model_dataset) - return pretrained_model_filename, None, model_application.component.network, None + model_info = parse_contents_data(model_info, model_info_filename) + model_application.update_pretrained_model_layout_with_info(model_info, model_info_filename) + return pretrained_model_filename, model_info_filename, None, model_application.component.network, None elif ihm_id == 'ml_instance_choice' : if model_application.ml_model is None or model_application.pretrained_model is None : raise PreventUpdate instance = parse_contents_instance(instance_contents, instance_filename) model_application.update_instance(instance, enum, xtype) - return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation elif ihm_id == 'number_explanations' : if model_application.ml_model is None or model_application.pretrained_model is None or model_application.instance is None: raise PreventUpdate instance = parse_contents_instance(model_application.instance, instance_filename) model_application.update_instance(instance, enum, xtype) - return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation elif ihm_id == 'explanation_type' : if model_application.ml_model is None or model_application.pretrained_model is None or model_application.instance is None: raise PreventUpdate instance = parse_contents_instance(model_application.instance, instance_filename) model_application.update_instance(instance, enum, xtype) - return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation elif ihm_id == 'solver_sat' : if model_application.ml_model is None or model_application.pretrained_model is None or model_application.instance is None: raise PreventUpdate instance = parse_contents_instance(model_application.instance, instance_filename) model_application.update_instance(instance, enum, xtype, solver=solver) - return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation elif ihm_id == 'expl_choice' : if instance_contents is None : raise PreventUpdate model_application.update_expl(expl_choice) - return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation @app.callback( @@ -125,3 +131,16 @@ def register_callbacks(page_home, page_course, page_application, app): for i in range (len(model_application.list_expls)): options[str(model_application.list_expls[i])] = model_application.list_expls[i] return False, False, False, options + + @app.callback( + Output('model_info_choice', 'disabled'), + Input('add_info_model_choice', 'value'), + prevent_initial_call=True + ) + def add_model_info(add_info_model_choice): + model_application = page_application.model + model_application.update_info_needed(add_info_model_choice) + if add_info_model_choice==1: + return False + else : + return True diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py index 020a956..a1e2363 100644 --- a/pages/application/DecisionTree/DecisionTreeComponent.py +++ b/pages/application/DecisionTree/DecisionTreeComponent.py @@ -15,34 +15,105 @@ from pages.application.DecisionTree.utils.dtviz import (visualize, class DecisionTreeComponent(): - def __init__(self, tree, dataset): + def __init__(self, tree, info=None, type_info=''): - data = Data(dataset) - fvmap = data.mapping_features() - - try: - feature_names = tree.feature_names_in_ - except: - print("You did not dump the model with the features names") - feature_names = [f'f{i}' for i in range(tree.n_features_in_)] - - self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_) - self.dt_format, self.map, features_names_mapping = self.uploaded_dt.dump(fvmap, feat_names=feature_names) - self.dt = DecisionTree(from_dt=self.dt_format, mapfile = self.map) + if info is not None and '.csv' in type_info: + self.categorical = True + data = Data(info) + fvmap = data.mapping_features() + feature_names = data.names[:-1] + + self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_) + self.dt_format, self.map, features_names_mapping = self.uploaded_dt.dump(fvmap, feat_names=feature_names) + self.mapping_instance = self.create_fvmap_inverse_with_info(features_names_mapping) + + + elif info is not None and '.txt' in type_info : + self.categorical = True + fvmap = {} + + self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_) + self.dt_format, self.map, features_names_mapping = self.uploaded_dt.dump(fvmap, feat_names=feature_names) + self.mapping_instance = self.create_fvmap_inverse_with_info(features_names_mapping) + + else : + self.categorical = False + try: + feature_names = tree.feature_names_in_ + except: + feature_names = [f'f{i}' for i in range(tree.n_features_in_)] + + self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_) + self.dt_format, self.map, features_names_mapping = self.uploaded_dt.convert_dt(feat_names=feature_names) + self.mapping_instance = self.create_fvmap_inverse_threashold(features_names_mapping) + + self.dt = DecisionTree(from_dt=self.dt_format, mapfile = self.map) dot_source = visualize(self.dt) self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(dot_source=dot_source, style = {"width": "60%", "height": "90%", "background-color": "transparent"}))] self.explanation = [] + + def create_fvmap_inverse_with_info(self, features_names_mapping) : + mapping_instance = {} + for feat in features_names_mapping : + feat_dic = {} + feature_description = feat.split(',') + name_feat, id_feat = feature_description[1].split(':') + + for mapping in feature_description[2:]: + real_value, mapped_value = mapping.split(':') + feat_dic[np.float32(real_value)] = int(mapped_value) + mapping_instance[name_feat] = feat_dic + + return mapping_instance + + + def create_fvmap_inverse_threashold(self, features_names_mapping) : + mapping_instance = {} + for feat in features_names_mapping : + feature_description = feat.split(',') + name_feat, id_feat = feature_description[1].split(':') + mapping_instance[name_feat] = float(feature_description[2].split(':')[0]) + + return mapping_instance + + + def translate_instance(self, instance): + def translate_instance_categorical(instance): + instance_translated = [] + for feat, real_value in instance : + instance_translated.append((feat, self.mapping_instance[feat][real_value])) + return instance_translated + + def translate_instance_threasholds(instance): + instance_translated = [] + for feat, real_value in instance : + try: + if real_value <= self.mapping_instance[feat]: + instance_translated.append((feat, 0)) + else : + instance_translated.append((feat, 1)) + except: + instance_translated.append((feat, real_value)) + return instance_translated + + if self.categorical : + return translate_instance_categorical(instance) + else : + return translate_instance_threasholds(instance) + + def update_with_explicability(self, instance, enum, xtype, solver) : + instance_translated = self.translate_instance(instance) self.explanation = [] list_explanations_path=[] - explanation = self.dt.explain(instance, enum=enum, xtype = xtype, solver=solver) + explanation = self.dt.explain(instance_translated, enum=enum, xtype = xtype, solver=solver) - dot_source = visualize_instance(self.dt, instance) + dot_source = visualize_instance(self.dt, instance_translated) self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz( dot_source=dot_source, style = {"width": "50%", "height": "80%", @@ -67,7 +138,7 @@ class DecisionTreeComponent(): return list_explanations_path def draw_explanation(self, instance, expl) : - print(expl) + instance = self.translate_instance(instance) dot_source = visualize_expl(self.dt, instance, expl) self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz( dot_source=dot_source, diff --git a/pages/application/DecisionTree/utils/data.py b/pages/application/DecisionTree/utils/data.py index e719b59..d23ddb6 100644 --- a/pages/application/DecisionTree/utils/data.py +++ b/pages/application/DecisionTree/utils/data.py @@ -99,4 +99,4 @@ class Data(object): for j, v in enumerate(sorted(self.feats[i])): fvmap[f'f{i}'][j+m] = (self.names[i], False, v) - return fvmap \ No newline at end of file + return fvmap diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py index bbd6046..1e7e270 100644 --- a/pages/application/DecisionTree/utils/dtree.py +++ b/pages/application/DecisionTree/utils/dtree.py @@ -332,15 +332,15 @@ class DecisionTree(): """ Compute a given number of explanations. """ - - self.feids = {f[0]: i for i, f in enumerate(inst)} - - path, term, depth = self.execute(inst, pathlits) - #contaiins all the elements for explanation explanation_dic = {} #instance plotting - explanation_dic["Instance : "] = str([self.fvmap[inst[i]] for i in range (len(inst))]) + explanation_dic["Instance : "] = str([str(inst[i]) for i in range (len(inst))]) + + self.feids = {f'f{i}': i for i, f in enumerate(inst)} + inst = [(f'f{i}', int(inst[i][1])) for i,f in enumerate(inst)] + + path, term, depth = self.execute(inst, pathlits) #decision path decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[inst[self.feids[self.nodes[n].feat]]] for n in path]), term) diff --git a/pages/application/DecisionTree/utils/upload_tree.py b/pages/application/DecisionTree/utils/upload_tree.py index ea82f0a..b3e1297 100644 --- a/pages/application/DecisionTree/utils/upload_tree.py +++ b/pages/application/DecisionTree/utils/upload_tree.py @@ -239,10 +239,11 @@ class UploadedDecisionTree: if feat_names is not None: - features_names_mapping = '' + features_names_mapping = [] for i,fid in enumerate(feat_names): - f=f'f{i}' - features_names_mapping += f'T:C,{fid}:{f},'+",".join([f'{fvmap[f][v][2]}:{v}' for v in fvmap[f] if(fvmap[f][v][1])])+'\n' + feat=f'f{i}' + f = f'T:C,{fid}:{feat},'+",".join([f'{fvmap[feat][v][2]}:{v}' for v in fvmap[feat] if(fvmap[feat][v][1])]) + features_names_mapping.append(f) return dt, map, features_names_mapping @@ -299,13 +300,15 @@ class UploadedDecisionTree: map += f"\n{f} {j} <={t}" map += f"\n{f} {j+1} >{t}" + if feat_names is not None: - features_names_mapping = '' + features_names_mapping = [] for i,fid in enumerate(feat_names): - f=f'f{i}' - if f in self.intvs: - features_names_mapping += f'\n Categorical,{fid}:{f},' - features_names_mapping += ",".join([f'{t}:{j}' for j,t in enumerate(self.intvs[f])]) + feat=f'f{i}' + if feat in self.intvs: + f = f'T:O,{fid}:{feat},' + f += ",".join([f'{t}:{j}' for j,t in enumerate(self.intvs[feat])]) + features_names_mapping.append(f) return dt, map, features_names_mapping diff --git a/pages/application/application.py b/pages/application/application.py index f5c7032..925a1b7 100644 --- a/pages/application/application.py +++ b/pages/application/application.py @@ -19,7 +19,9 @@ class Model(): self.ml_model = '' self.pretrained_model = '' - self.model_dataset = '' + + self.add_info = False + self.model_info = '' self.instance = '' @@ -37,9 +39,15 @@ class Model(): def update_pretrained_model(self, pretrained_model_update): self.pretrained_model = pretrained_model_update - def update_pretrained_model_dataset(self, model_dataset): - self.model_dataset = model_dataset - self.component = self.component_class(self.pretrained_model, self.model_dataset) + def update_info_needed(self, add_info): + self.add_info = add_info + + def update_pretrained_model_layout(self): + self.component = self.component_class(self.pretrained_model) + + def update_pretrained_model_layout_with_info(self, model_info, model_info_filename): + self.model_info = model_info + self.component = self.component_class(self.pretrained_model, self.model_info, model_info_filename) def update_instance(self, instance, enum, xtype, solver="g3"): self.instance = instance @@ -57,7 +65,13 @@ class View(): self.ml_menu_models = dcc.Dropdown(self.model.ml_models, id='ml_model_choice', className="sidebar-dropdown") - + + self.ml_library_used = dcc.Dropdown(options = [{'label': 'Scikit-learn ', 'value': "SKL"}, + {'label': 'ITI', 'value': "ITI"}, + {'label': 'IAI', 'value': "IAI"}], + id='ml_library_choice', + className="sidebar-dropdown") + self.pretrained_model_upload = html.Div([ dcc.Upload( id='ml_pretrained_model_choice', @@ -69,15 +83,22 @@ class View(): ), html.Div(id='pretrained_model_filename')]) - self.model_dataset = html.Div([ + self.add_model_info_choice = dcc.RadioItems(id="add_info_model_choice", + options = [{'label': 'Yes ', 'value': 1}, + {'label': 'No', 'value': 0}], + value=0, className="sidebar-dropdown") + + self.model_info = html.Div([ dcc.Upload( - id='model_dataset_choice', + id='model_info_choice', + disabled=True, children=html.Div([ 'Drag and Drop or ', html.A('Select File') ]), className="upload" - )]) + ), + html.Div(id='info_filename')]) self.instance_upload = html.Div([ dcc.Upload( @@ -97,17 +118,27 @@ class View(): html.Br(), self.ml_menu_models, html.Hr(), + html.Label("Choose the Machine Learning library used :"), + html.Br(), + self.ml_library_used, + html.Hr(), html.Label("Choose the pretrained model : "), html.Br(), self.pretrained_model_upload, html.Hr(), - html.Label("Choose the pretrained model dataset : "), - self.model_dataset, + html.Label("Do you wish to upload more info for your model ? : "), + html.Br(), + self.add_model_info_choice, html.Hr(), - html.Label("Choose the instance to explain : "), + html.Label("Choose the pretrained model dataset (csv) or feature definition file (txt): "), html.Br(), - self.instance_upload, + self.model_info, html.Hr(), + html.Label("Choose the instance to explain : "), + html.Br(), + self.instance_upload], className="sidebar"), + dcc.Tab(label='Advanced Parameters', children = [ + html.Br(), html.Label("Choose the number of explanations : "), html.Br(), dcc.Input( @@ -124,17 +155,14 @@ class View(): options={'AXp' : "Abductive Explanation", 'CXp': "Contrastive explanation"}, value = ['AXp', 'CXp'], className="sidebar-dropdown", - inline=True)], className="sidebar"), - dcc.Tab(label='Advanced Parameters', children = [ + inline=True), html.Hr(), html.Label("Choose the SAT solver : "), html.Br(), dcc.Dropdown(['g3', 'g4', 'lgl', 'mcb', 'mcm', 'mpl', 'm22', 'mc', 'mgh'], 'g3', id='solver_sat') - ], className="sidebar") - ]) + ], className="sidebar")]) - self.expl_choice = dcc.Dropdown(self.model.list_expls, id='expl_choice', className="dropdown") diff --git a/utils.py b/utils.py index 4b38288..b342910 100644 --- a/utils.py +++ b/utils.py @@ -27,6 +27,8 @@ def parse_contents_data(contents, filename): try: if '.csv' in filename: data = decoded.decode('utf-8') + if '.txt' in filename: + data = decoded.decode('utf-8') except Exception as e: print(e) return html.Div([ @@ -39,9 +41,11 @@ def parse_contents_instance(contents, filename): content_type, content_string = contents.split(',') decoded = base64.b64decode(content_string) try: - if 'csv' in filename: + if '.csv' in filename: data = decoded.decode('utf-8') - elif 'txt' in filename: + elif '.txt' in filename: + data = decoded.decode('utf-8') + elif '.json' in filename: data = decoded.decode('utf-8') else : data = decoded.decode('utf-8') -- GitLab