diff --git a/callbacks.py b/callbacks.py index 99694e4b0ba0509e1c7d98819128164bb297f825..4894d7592e2deee27063dc94348575d27ca0fbdd 100644 --- a/callbacks.py +++ b/callbacks.py @@ -1,3 +1,5 @@ +import ast + import dash from dash.dependencies import Input, Output, State import dash_bootstrap_components as dbc @@ -7,6 +9,7 @@ from utils import parse_contents_graph, parse_contents_instance, parse_contents_ from pages.application.RandomForest.utils import xrf from pages.application.RandomForest.utils.xrf import * + sys.modules['xrf'] = xrf @@ -25,6 +28,11 @@ def register_callbacks(page_home, page_course, page_application, app): color='warning', duration=10000, ), ]) + warning_selection_data = html.Div([dbc.Alert("You uploaded the model, now upload the data.", + is_open=True, + color='warning', + duration=10000, ), ]) + alert_network = html.Div([dbc.Alert("There was a problem while computing the graph, read the documentation. \ You might have forgotten to upload the data for Random Forest or you tried to upload an unknown format.", is_open=True, @@ -77,13 +85,13 @@ def register_callbacks(page_home, page_course, page_application, app): Input('ml_pretrained_model_choice', 'filename'), prevent_initial_call=True ) - def update_model_prertrained_name(value_ml_model, pretrained_model_filename): + def update_model_pretrained_name(value_ml_model, pretrained_model_filename): ctx = dash.callback_context if ctx.triggered: ihm_id = ctx.triggered[0]['prop_id'].split('.')[0] if ihm_id == 'ml_model_choice': return None - else : + else: return pretrained_model_filename @app.callback(Output('info_filename', 'children'), @@ -97,7 +105,7 @@ def register_callbacks(page_home, page_course, page_application, app): ihm_id = ctx.triggered[0]['prop_id'].split('.')[0] if ihm_id == 'ml_model_choice': return None - else : + else: return model_info_filename @app.callback(Output('instance_filename', 'children'), @@ -114,9 +122,22 @@ def register_callbacks(page_home, page_course, page_application, app): return None elif ihm_id == 'ml_pretrained_model_choice': return None - else : + else: return instance_filename + @app.callback( + Output('choice_info_div', 'hidden'), + Input('add_info_model_choice', 'on'), + prevent_initial_call=True + ) + def add_model_info(add_info_model_choice): + model_application = page_application.model + model_application.update_info_needed(add_info_model_choice) + if add_info_model_choice: + return False + else: + return True + @app.callback( Output('graph', 'children'), Output('explanation', 'children'), @@ -150,13 +171,16 @@ def register_callbacks(page_home, page_course, page_application, app): # Choice of information for the model - data, feature mapping ... elif ihm_id == 'model_info_choice': - try : + try: model_info = parse_contents_data(model_info, model_info_filename) - model_application.model_info = model_info + model_application.update_info(model_info) if model_application.ml_model is None: return warning_selection_model, None - model_application.update_pretrained_model_layout_with_info(model_info, model_info_filename) - return model_application.component.network, None + elif model_application.pretrained_model is not None: + model_application.update_pretrained_model_layout_with_info(model_info_filename) + return model_application.component.network, None + else: + return warning_selection_model, None except: return warning_selection_pretrained_model, None @@ -168,10 +192,11 @@ def register_callbacks(page_home, page_course, page_application, app): if not model_application.add_info: model_application.update_pretrained_model_layout() return model_application.component.network, None - else: - model_application.update_pretrained_model_layout_with_info(model_application.model_info, - model_info_filename) + elif model_application.model_info is not None: + model_application.update_pretrained_model_layout_with_info(model_info_filename) return model_application.component.network, None + else: + return warning_selection_data, None except: return alert_network, None @@ -184,7 +209,6 @@ def register_callbacks(page_home, page_course, page_application, app): except: return model_application.component.network, alert_explanation - # Choice of number of expls elif ihm_id == 'number_explanations': try: @@ -211,11 +235,11 @@ def register_callbacks(page_home, page_course, page_application, app): # Choice of AxP to draw elif ihm_id == 'expl_choice': - try : + try: model_application.update_expl(expl_choice) return model_application.component.network, model_application.component.explanation except: - return model_application.component.network, alert_explanation + return alert_network, model_application.component.explanation # Choice of CxP to draw elif ihm_id == 'cont_expl_choice': @@ -223,15 +247,15 @@ def register_callbacks(page_home, page_course, page_application, app): model_application.update_cont_expl(cont_expl_choice) return model_application.component.network, model_application.component.explanation except: - return model_application.component.network, alert_explanation + return alert_network, model_application.component.explanation # In the case of RandomForest, id of tree to choose to draw tree elif ihm_id == 'choice_tree': - try : + try: model_application.update_tree_to_plot(id_tree) return model_application.component.network, model_application.component.explanation except: - return model_application.component.network, alert_explanation + return alert_network, model_application.component.explanation @app.callback( Output('explanation', 'hidden'), @@ -245,49 +269,54 @@ def register_callbacks(page_home, page_course, page_application, app): else: return False + ########### RandomForest ########### + @app.callback( - Output('choice_info_div', 'hidden'), - Input('add_info_model_choice', 'on'), + Output('choosing_tree', 'hidden'), + Input('ml_model_choice', 'value'), + Input('graph', 'children'), prevent_initial_call=True ) - def add_model_info(add_info_model_choice): - model_application = page_application.model - model_application.update_info_needed(add_info_model_choice) - if add_info_model_choice: + def choose_tree_in_forest(ml_type, graph): + if ml_type == "RandomForest" and graph is not None: return False else: return True - ########### RandomForest ########### + ########### DecistionTree ########### + @app.callback( - Output('choosing_tree', 'hidden'), - Input('graph', 'children'), + Output('div_switcher_draw_expl', 'hidden'), + Input('ml_model_choice', 'value'), prevent_initial_call=True ) - def choose_tree_in_forest(graph): - if page_application.model.ml_model == "RandomForest" and graph is not None: - return False - else: + def show_switcher_draw(ml_type): + if ml_type != "DecisionTree": return True + else: + return False + + @app.callback( + Output('drawing_expl', 'on'), + Input('ml_model_choice', 'value'), + Input('ml_pretrained_model_choice', 'contents'), + Input('model_info_choice', 'contents'), + Input('ml_instance_choice', 'contents'), + prevent_initial_call=True + ) + def turn_switcher_draw_off(ml_type, model, data, instance): + return False - ########### DecistionTree ########### @app.callback( Output('interaction_graph', 'hidden'), Output('expl_choice', 'options'), Output('cont_expl_choice', 'options'), - Input('explanation', 'children'), - Input('explanation_type', 'value'), + Input('drawing_expl', 'on'), prevent_initial_call=True ) - def layout_buttons_navigate_expls(explanation, explanation_type): - if page_application.model.ml_model == "DecisionTree" and explanation is not None and len(explanation_type) > 0: - options_expls = {} - options_cont_expls = {} - model_application = page_application.model - for i in range(len(model_application.list_expls)): - options_expls[str(model_application.list_expls[i])] = model_application.list_expls[i] - for i in range(len(model_application.list_cont_expls)): - options_cont_expls[str(model_application.list_cont_expls[i])] = model_application.list_cont_expls[i] - return False, options_expls, options_cont_expls - else: + def switcher_drawing_options(bool_draw): + if not bool_draw: return True, {}, {} + else: + model_application = page_application.model + return False, model_application.options_expls, model_application.options_cont_expls diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py index 19c7904e8612dedbfa0cfe6fbd65c6df19d4837f..8087481ceb89fac6a515863ccf8252168c54d6af 100644 --- a/pages/application/DecisionTree/DecisionTreeComponent.py +++ b/pages/application/DecisionTree/DecisionTreeComponent.py @@ -150,6 +150,7 @@ class DecisionTreeComponent: return list_explanations_path, list_contrastive_explanations_path def draw_explanation(self, instance, expl): + instance = instance[0] instance = self.translate_instance(instance) dot_source = visualize_expl(self.dt, instance, expl) self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( @@ -159,6 +160,7 @@ class DecisionTreeComponent: "background-color": "transparent"})]) def draw_contrastive_explanation(self, instance, cont_expl): + instance = instance[0] instance = self.translate_instance(instance) dot_source = visualize_contrastive_expl(self.dt, instance, cont_expl) self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( diff --git a/pages/application/RandomForest/utils/xrf/xforest.py b/pages/application/RandomForest/utils/xrf/xforest.py index 749e3c42b768d8fcb1d1b455ff3b60953813ada9..ccdeadb2ae22e159b525e37d9fcfabf0f9fb3fb0 100644 --- a/pages/application/RandomForest/utils/xrf/xforest.py +++ b/pages/application/RandomForest/utils/xrf/xforest.py @@ -209,9 +209,9 @@ class XRF(object): if self.x.slv is not None: self.x.slv.delete() del self.x - del self.f + #del self.f self.f = None - del self.cls + #del self.cls self.cls = None def encode(self, inst): diff --git a/pages/application/application.py b/pages/application/application.py index f502d764335d5148953904c9c674725975a3503f..4ce38fb35b8df8b484673946bc378c1afd729ac3 100644 --- a/pages/application/application.py +++ b/pages/application/application.py @@ -3,7 +3,6 @@ import dash_bootstrap_components as dbc import dash_daq as daq from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent -from pages.application.NaiveBayes.NaiveBayesComponent import NaiveBayesComponent from pages.application.RandomForest.RandomForestComponent import RandomForestComponent @@ -39,9 +38,8 @@ class Model: self.instance = None - self.list_expls = [] - self.list_cont_expls = [] - + self.options_expls = {} + self.options_cont_expls = {} self.expl = None self.cont_expl = None @@ -58,22 +56,36 @@ class Model: self.xtype = [list(self.xtypes.keys())[0]] #init all params + self.component = None self.pretrained_model = None self.model_info = None self.instance = None + self.options_expls = {} + self.options_cont_expls = {} + self.expl = None + self.cont_expl = None + def update_pretrained_model(self, pretrained_model_update): self.pretrained_model = pretrained_model_update + self.component = None self.instance = None + self.options_expls = {} + self.options_cont_expls = {} + self.expl = None + self.cont_expl = None + def update_info_needed(self, add_info): self.add_info = add_info + def update_info(self, model_info): + self.model_info = model_info + def update_pretrained_model_layout(self): self.component = self.component_class(self.pretrained_model) - def update_pretrained_model_layout_with_info(self, model_info, model_info_filename): - self.model_info = model_info + def update_pretrained_model_layout_with_info(self, model_info_filename): self.component = self.component_class(self.pretrained_model, info=self.model_info, type_info=model_info_filename) @@ -82,31 +94,56 @@ class Model: def update_instance(self, instance): self.instance = instance - self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, + list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) + self.options_expls = {} + self.options_cont_expls = {} + for i in range(len(list_expls)): + self.options_expls[str(list_expls[i])] = list_expls[i] + for i in range(len(list_cont_expls)): + self.options_cont_expls[str(list_cont_expls[i])] = list_cont_expls[i] def update_enum(self, enum): self.enum = enum - self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, + list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) + self.options_expls = {} + self.options_cont_expls = {} + for i in range(len(list_expls)): + self.options_expls[str(list_expls[i])] = list_expls[i] + for i in range(len(list_cont_expls)): + self.options_cont_expls[str(list_cont_expls[i])] = list_cont_expls[i] + def update_xtype(self, xtype): self.xtype = xtype - self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, + list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) + self.options_expls = {} + self.options_cont_expls = {} + for i in range(len(list_expls)): + self.options_expls[str(list_expls[i])] = list_expls[i] + for i in range(len(list_cont_expls)): + self.options_cont_expls[str(list_cont_expls[i])] = list_cont_expls[i] def update_solver(self, solver): self.solver = solver - self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, + list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) + self.options_expls = {} + self.options_cont_expls = {} + for i in range(len(list_expls)): + self.options_expls[str(list_expls[i])] = list_expls[i] + for i in range(len(list_cont_expls)): + self.options_cont_expls[str(list_cont_expls[i])] = list_cont_expls[i] def update_expl(self, expl): self.expl = expl - self.component.draw_explanation(self.instance, expl) + self.component.draw_explanation(self.instance, self.expl) def update_cont_expl(self, cont_expl): - self.expl = cont_expl - self.component.draw_contrastive_explanation(self.instance, cont_expl) + self.cont_expl = cont_expl + self.component.draw_contrastive_explanation(self.instance, self.cont_expl) class View: @@ -213,14 +250,19 @@ class View: self.solver ], className="sidebar")]) + self.switch = html.Div(id="div_switcher_draw_expl", hidden=True, + children=[html.Label("Draw explanations ?"), + html.Br(), + daq.BooleanSwitch(id='drawing_expl', on=False, color="#FFFFFF", )]) + self.expl_choice = html.Div(id="interaction_graph", hidden=True, children=[html.H5("Navigate through the explanations and plot them on the tree : "), - html.Div(children=[dcc.Dropdown(self.model.list_expls, + html.Div(children=[dcc.Dropdown(self.model.options_expls, id='expl_choice', className="dropdown")]), html.H5( "Navigate through the contrastive explanations and plot them on the tree : "), - html.Div(children=[dcc.Dropdown(self.model.list_cont_expls, + html.Div(children=[dcc.Dropdown(self.model.options_cont_expls, id='cont_expl_choice', className="dropdown")])]) @@ -233,6 +275,7 @@ class View: self.layout = dbc.Row([dbc.Col([self.sidebar], width=3, class_name="sidebar"), dbc.Col([dbc.Row(id="graph", children=[]), + dbc.Row(self.switch), dbc.Row(self.expl_choice), dbc.Row(self.tree_to_plot)], width=5, class_name="column_graph"),