diff --git a/callbacks.py b/callbacks.py index 47e38973a543e62fc2cd672dd61cc9b0697c68e4..06b112f9a3e78088f141fe86f96c473977e239bd 100644 --- a/callbacks.py +++ b/callbacks.py @@ -14,6 +14,13 @@ from pages.application.RandomForest.utils import xrf from pages.application.RandomForest.utils.xrf import * sys.modules['xrf'] = xrf +from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent +from pages.application.RandomForest.RandomForestComponent import RandomForestComponent + +""" +The callbacks are called whenever there is an interaction with the interface +""" + def register_callbacks(app): page_list = ['home', 'course', 'application'] @@ -30,10 +37,8 @@ def register_callbacks(app): data = json.load(models_data)["data"] # For the application names_models, dict_components, dic_solvers, dic_xtypes = extract_data(data) - model_application = Model(names_models, dict_components, dic_solvers, dic_xtypes) - view_application = View(model_application) - ####### Creates alerts when callback fails ######### + # region alerts warning_selection_model = html.Div([dbc.Alert("You didn't choose a king of Machine Learning model first.", is_open=True, color='warning', @@ -69,6 +74,8 @@ def register_callbacks(app): color='info', duration=5000, ), ]) + # endregion + ###################################################### @app.callback( @@ -78,7 +85,8 @@ def register_callbacks(app): if pathname == '/': return page_home if pathname == '/application': - return view_application.layout + model_application = Model(names_models, dict_components, dic_solvers, dic_xtypes) + return View(model_application).layout if pathname == '/course': return page_course @@ -90,7 +98,7 @@ def register_callbacks(app): active_link = ([pathname == f'/{i}' for i in page_list]) return active_link[0], active_link[1], active_link[2] - # region mltype + # region ml type @app.callback(Output('solver_sat', 'options'), Output('solver_sat', 'value'), Output('explanation_type', 'options'), @@ -99,9 +107,14 @@ def register_callbacks(app): prevent_initial_call=True ) def update_ml_type_options(ml_type): - model_application.update_ml_model(ml_type) - return model_application.solvers, model_application.solvers[0], model_application.xtypes, [ - list(model_application.xtypes.keys())[0]] + if ml_type is not None: + solvers = dic_solvers[ml_type] + solver = solvers[0] + xtypes = dic_xtypes[ml_type] + xtype = [list(xtypes.keys())[0]] + return solvers, solver, xtypes, xtype + else : + return [], None, [], None # endregion @@ -109,18 +122,15 @@ def register_callbacks(app): @app.callback( Output('pretrained_model_filename', 'children'), Input('ml_model_choice', 'value'), - Input('ml_pretrained_model_choice', 'contents'), - State('ml_pretrained_model_choice', 'filename'), + Input('ml_pretrained_model_choice', 'filename'), prevent_initial_call=True) - def select_model(ml_type, model, filename): + def select_model(ml_type, filename): ctx = dash.callback_context if ctx.triggered: ihm_id = ctx.triggered_id if ihm_id == 'ml_model_choice': return None - else: - graph = parse_contents_graph(model, filename) - model_application.update_pretrained_model(graph) + else : return filename # endregion @@ -132,7 +142,6 @@ def register_callbacks(app): prevent_initial_call=True ) def delete_info(ml_type): - model_application.update_info_needed(False) return False @app.callback( @@ -141,7 +150,6 @@ def register_callbacks(app): prevent_initial_call=True ) def add_model_info(add_info_model_choice): - model_application.update_info_needed(add_info_model_choice) if add_info_model_choice: return False else: @@ -162,7 +170,6 @@ def register_callbacks(app): return None, None else: model_info = parse_contents_data(data, filename) - model_application.update_info(model_info) return filename, model_info # endregion @@ -173,30 +180,15 @@ def register_callbacks(app): Input('ml_pretrained_model_choice', 'contents'), Input('ml_instance_choice', 'contents'), State('ml_instance_choice', 'filename'), - Input('number_explanations', 'value'), - Input('explanation_type', 'value'), - Input('solver_sat', 'value'), prevent_initial_call=True ) - def select_instance(ml_type, model, instance, filename, enum, xtype, solver): + def select_instance(ml_type, model, instance, filename): ctx = dash.callback_context if ctx.triggered: ihm_id = ctx.triggered_id - if ihm_id == 'ml_model_choice' or ihm_id == 'ml_pretrained_model_choice': + if ihm_id == 'ml_model_choice' or ihm_id == "ml_pretrained_model_choice": return None - else: - # Choice of number of expls - if ihm_id == 'number_explanations': - model_application.update_enum(enum) - # Choice of AxP or CxP - elif ihm_id == 'explanation_type': - model_application.update_xtype(xtype) - # Choice of solver - elif ihm_id == 'solver_sat': - model_application.update_solver(solver) - else: - instance = parse_contents_instance(instance, filename) - model_application.update_instance(instance) + else : return filename # endregion @@ -204,50 +196,85 @@ def register_callbacks(app): # region draw @app.callback( Output('graph', 'children'), + Output('explanation', 'children'), + Output('expl_choice', 'options'), + Output('cont_expl_choice', 'options'), Input('ml_model_choice', 'value'), Input('ml_pretrained_model_choice', 'contents'), + State('ml_pretrained_model_choice', 'filename'), + State('add_info_model_choice', 'on'), + State('intermediate-value-data', 'data'), + Input('ml_instance_choice', 'contents'), + State('ml_instance_choice', 'filename'), Input('submit-model', 'n_clicks'), State('model_info_choice', 'filename'), Input('expl_choice', 'value'), Input('cont_expl_choice', 'value'), + Input('number_explanations', 'value'), + Input('explanation_type', 'value'), + Input('solver_sat', 'value'), Input('choice_tree', 'value'), prevent_initial_call=True) - def draw_model(ml_type, model, click, model_info_filename, expl_choice, cont_expl_choice, id_tree): + def draw_model(ml_type, pretrained_model, model_filename, need_data, data, instance, instance_filename, click, model_info_filename, + expl_choice, cont_expl_choice, enum, xtype, solver, id_tree): ctx = dash.callback_context - try: - if ctx.triggered: - ihm_id = ctx.triggered_id - if ihm_id == "ml_model_choice" : - return reinit - elif ihm_id == "ml_pretrained_model_choice": - return init_network - elif ihm_id == "submit-model": - if model_application.ml_model is None: - return warning_selection_model - elif not model_application.add_info: - model_application.update_pretrained_model_layout() - return model_application.component.network - elif model_application.model_info is not None: - model_application.update_pretrained_model_layout_with_info(model_info_filename) - return model_application.component.network - else: - return warning_selection_data + if ctx.triggered: + ihm_id = ctx.triggered_id - elif ihm_id == 'expl_choice': - model_application.update_expl(expl_choice) - return model_application.component.network + if ihm_id == "ml_model_choice" : + return reinit, None, {}, {} - # Choice of CxP to draw - elif ihm_id == 'cont_expl_choice': - model_application.update_cont_expl(cont_expl_choice) - return model_application.component.network + elif ihm_id == "ml_pretrained_model_choice": + return init_network, None, {}, {} + + else : + try: + pretrained_model = parse_contents_graph(pretrained_model, model_filename) + if ml_type is None: + return warning_selection_model, None, {}, {} + else: + component_class = dict_components[ml_type] + component_class = globals()[component_class] + if not need_data: + component = component_class(pretrained_model) + elif data is not None: + component = component_class(pretrained_model, info=data, type_info=model_info_filename) + else: + return warning_selection_data, None, {}, {} - # In the case of RandomForest, id of tree to choose to draw tree - elif ihm_id == 'choice_tree': - model_application.update_tree_to_plot(id_tree) - return model_application.component.network - except: - return alert_network + except: + return alert_network, None, {}, {} + + else : + if ihm_id == "submit-model": + return component.network, None, {}, {} + + # In the case of RandomForest, id of tree to choose to draw tree + elif ihm_id == 'choice_tree': + component.update_plotted_tree(id_tree) + return component.network, component.explanation, {}, {} + + elif instance is not None: + instance = parse_contents_instance(instance, instance_filename) + list_expls, list_cont_expls = component.update_with_explicability(instance, enum, xtype, solver) + options_expls = {} + options_cont_expls = {} + for i in range(len(list_expls)): + options_expls[str(list_expls[i])] = list_expls[i] + for i in range(len(list_cont_expls)): + options_cont_expls[str(list_cont_expls[i])] = list_cont_expls[i] + + if ihm_id == 'expl_choice': + component.draw_explanation(instance, expl_choice) + return component.network, component.explanation, options_expls, options_cont_expls + + # Choice of CxP to draw + elif ihm_id == 'cont_expl_choice': + component.draw_contrastive_explanation(instance, cont_expl_choice) + return component.network, component.explanation, options_expls, options_cont_expls + + else : + return component.network, component.explanation, options_expls, options_cont_expls # endregion @@ -255,52 +282,24 @@ def register_callbacks(app): @app.callback( Output('explanation', 'hidden'), - Input('submit-instance', 'n_clicks'), Input('ml_model_choice', 'value'), Input('ml_pretrained_model_choice', 'contents'), - Input('ml_instance_choice', 'contents'), + State('ml_instance_choice', 'contents'), + Input('submit-instance', 'n_clicks'), prevent_initial_call=True ) - def show_explanation_window(click, ml_type, model, instance): + def show_explanation_window(ml_type, model, instante, click): ctx = dash.callback_context if ctx.triggered: ihm_id = ctx.triggered_id if ihm_id == "ml_model_choice" or ihm_id == "ml_pretrained_model_choice" or ihm_id == "ml_instance_choice": return True - elif ihm_id == "submit-instance": - return False else : - return True - - @app.callback( - Output('explanation', 'children'), - Input('ml_model_choice', 'value'), - Input('ml_pretrained_model_choice', 'contents'), - Input('ml_instance_choice', 'contents'), - Input('submit-instance', 'n_clicks'), - prevent_initial_call=True) - def run_explanation(ml_type, model, instance, click): - ctx = dash.callback_context - if ctx.triggered: - ihm_id = ctx.triggered_id - if ihm_id == "ml_model_choice" or ihm_id == "ml_pretrained_model_choice" or ihm_id == "ml_instance_choice": - return reinit - elif ihm_id == "submit-instance": - try: - if model_application.ml_model is None: - return warning_selection_model - elif model_application.pretrained_model is None: - return warning_selection_pretrained_model - elif model_application.instance is not None: - return model_application.component.explanation - else: - return warning_selection_data - except: - return alert_explanation + return False # endregion - ########### RandomForest ########### + # region randomforest @app.callback( Output('choosing_tree', 'hidden'), @@ -308,12 +307,11 @@ def register_callbacks(app): prevent_initial_call=True ) def choose_tree_in_forest(ml_type): - if ml_type == "RandomForest": - return False - else: - return True + return bool(ml_type != "RandomForest") + + # endregion - ########### DecistionTree ########### + # region decisiontree @app.callback( Output('div_switcher_draw_expl', 'hidden'), @@ -321,10 +319,7 @@ def register_callbacks(app): prevent_initial_call=True ) def show_switcher_draw(ml_type): - if ml_type != "DecisionTree": - return True - else: - return False + return bool(ml_type != "DecisionTree") @app.callback( Output('drawing_expl', 'on'), @@ -339,13 +334,10 @@ def register_callbacks(app): @app.callback( Output('interaction_graph', 'hidden'), - Output('expl_choice', 'options'), - Output('cont_expl_choice', 'options'), Input('drawing_expl', 'on'), prevent_initial_call=True ) def switcher_drawing_options(bool_draw): - if not bool_draw: - return True, {}, {} - else: - return False, model_application.options_expls, model_application.options_cont_expls + return not bool_draw + + # endregion \ No newline at end of file diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py index 8087481ceb89fac6a515863ccf8252168c54d6af..81f00168f609cc3bd2261bb5c4de42ae84542e18 100644 --- a/pages/application/DecisionTree/DecisionTreeComponent.py +++ b/pages/application/DecisionTree/DecisionTreeComponent.py @@ -9,9 +9,11 @@ from pages.application.DecisionTree.utils.upload_tree import UploadedDecisionTre class DecisionTreeComponent: + """ The component for Decision Tree models""" def __init__(self, tree, info=None, type_info=''): + # creation of model if info is not None and '.csv' in type_info: self.categorical = True data = Data(info) @@ -52,10 +54,13 @@ class DecisionTreeComponent: self.mapping_instance = self.create_fvmap_inverse(features_names_mapping) self.dt = DecisionTree(from_dt=self.dt_format, mapfile=self.map, feature_names=feature_names) dot_source = visualize(self.dt) + + # visual self.network = html.Div( [dash_interactive_graphviz.DashInteractiveGraphviz(dot_source=dot_source, style={"width": "60%", "height": "90%", "background-color": "transparent"})]) + # init explanation self.explanation = [] def create_fvmap_inverse(self, instance): @@ -112,7 +117,13 @@ class DecisionTreeComponent: return translate_instance_threasholds(instance) def update_with_explicability(self, instance, enum, xtype, solver): - + r""" Called when an instance is upload or when you press the button "Submit for explanation" with advanced parameters. + Args: + instance : list - list of instance to explain + enum : number of explanation to calculate + xtype : kind of explanation + solver : solver + """ instance = instance[0] instance_translated = self.translate_instance(instance) @@ -139,7 +150,7 @@ class DecisionTreeComponent: list_explanations_path = explanation["List of path explanation(s)"] list_contrastive_explanations_path = explanation["List of path contrastive explanation(s)"] - # Create graph + # visual dot_source = visualize_instance(self.dt, instance_translated) self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( dot_source=dot_source, style={"width": "50%", @@ -150,8 +161,14 @@ class DecisionTreeComponent: return list_explanations_path, list_contrastive_explanations_path def draw_explanation(self, instance, expl): + r""" Called with the selection of an explanation to plot on the tree + Args: + instance : list - list of instance to explain + expl : the explanation path to draw on the tree + """ instance = instance[0] instance = self.translate_instance(instance) + # visual dot_source = visualize_expl(self.dt, instance, expl) self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( dot_source=dot_source, @@ -160,8 +177,14 @@ class DecisionTreeComponent: "background-color": "transparent"})]) def draw_contrastive_explanation(self, instance, cont_expl): + r""" Called with the selection of contrastive explanation to plot on the tree + Args: + instance : list - list of instance to explain + cont_expl : the contrastive explanation path to draw on the tree + """ instance = instance[0] instance = self.translate_instance(instance) + # visual dot_source = visualize_contrastive_expl(self.dt, instance, cont_expl) self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( dot_source=dot_source, diff --git a/pages/application/RandomForest/RandomForestComponent.py b/pages/application/RandomForest/RandomForestComponent.py index 644b17bce651b4ae65ba940eeade4c1d35d9e99e..38522ba30353ae06e6a54059f5dde35e587fecdd 100644 --- a/pages/application/RandomForest/RandomForestComponent.py +++ b/pages/application/RandomForest/RandomForestComponent.py @@ -6,16 +6,19 @@ from pages.application.RandomForest.utils.xrf.xforest import XRF, Dataset class RandomForestComponent: + """ The component for Random Forest models""" def __init__(self, model, info=None, type_info=''): # Conversion model self.data = Dataset(info) + # creation of model if info is not None and 'csv' in type_info: self.random_forest = XRF(model, self.data.feature_names, self.data.target_name) - # encoding here so not in the explanation + # encoding here so not in the explanation ? + # visual self.tree_to_plot = 0 dot_source = tree.export_graphviz(self.random_forest.cls.estimators()[self.tree_to_plot], feature_names=self.data.feature_names, class_names=list(map(lambda cl : str(cl), self.data.target_name)), @@ -25,10 +28,18 @@ class RandomForestComponent: "height": "80%", "background-color": "transparent"} )]) + + # init explanation self.explanation = [] def update_with_explicability(self, instances, enum_feats=None, xtype=None, solver=None): - + r""" Called when an instance is upload or when you press the button "Submit for explanation" with advanced parameters. + Args: + instances : list - list of instance to explain + enum_feats : ghost feature + xtype : ghost feature + solver : ghost feature + """ instances = [list(map(lambda feature: feature[1], instance)) for instance in instances] self.explanation = [] for instance in instances: @@ -59,10 +70,16 @@ class RandomForestComponent: return [], [] def update_plotted_tree(self, tree_to_plot): + r""" Called by a slider to choose which tree in a random forest to plot + Args: + tree_to_plot : int - the id of the tree to plot + """ + # visual self.tree_to_plot = tree_to_plot dot_source = tree.export_graphviz(self.random_forest.cls.estimators()[self.tree_to_plot], feature_names=self.data.feature_names, class_names=list(map(lambda cl : str(cl), self.data.target_name)), impurity=False, filled=False, rounded=True) + self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( dot_source=dot_source, style={"width": "50%", "height": "80%", diff --git a/pages/application/application.py b/pages/application/application.py index d952c3a24d8b22a3506e7fa065d98ba1112cb10e..c53d260ec6fba2e8d863efb3dc5e79325bff7bd5 100644 --- a/pages/application/application.py +++ b/pages/application/application.py @@ -2,10 +2,9 @@ from dash import dcc, html import dash_bootstrap_components as dbc import dash_daq as daq -from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent -from pages.application.RandomForest.RandomForestComponent import RandomForestComponent class Model: + """ Initialisation of the components into the layout""" def __init__(self, names_models, dict_components, dic_solvers, dic_xtypes): self.dict_components = dict_components @@ -38,116 +37,16 @@ class Model: self.component_class = None self.component = None - def update_ml_model(self, ml_model_update): - self.ml_model = ml_model_update - self.component_class = self.dict_components[self.ml_model] - self.component_class = globals()[self.component_class] - self.solvers = self.dic_solvers[self.ml_model] - self.solver = self.solvers[0] - self.xtypes = self.dic_xtypes[self.ml_model] - self.xtype = [list(self.xtypes.keys())[0]] - - # init all params - self.component = None - self.pretrained_model = None - self.model_info = None - self.instance = None - - self.options_expls = {} - self.options_cont_expls = {} - self.expl = None - self.cont_expl = None - - def update_pretrained_model(self, pretrained_model_update): - self.pretrained_model = pretrained_model_update - self.component = None - self.instance = None - - self.options_expls = {} - self.options_cont_expls = {} - self.expl = None - self.cont_expl = None - - def update_info_needed(self, add_info): - if not add_info: - self.model_info = None - self.add_info = add_info - - def update_info(self, model_info): - self.model_info = model_info - - def update_pretrained_model_layout(self): - self.component = self.component_class(self.pretrained_model) - - def update_pretrained_model_layout_with_info(self, model_info_filename): - self.component = self.component_class(self.pretrained_model, info=self.model_info, - type_info=model_info_filename) - - def update_tree_to_plot(self, id_tree): - self.component.update_plotted_tree(id_tree) - - def update_instance(self, instance): - self.instance = instance - list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, - self.xtype, self.solver) - self.options_expls = {} - self.options_cont_expls = {} - for i in range(len(list_expls)): - self.options_expls[str(list_expls[i])] = list_expls[i] - for i in range(len(list_cont_expls)): - self.options_cont_expls[str(list_cont_expls[i])] = list_cont_expls[i] - - def update_enum(self, enum): - self.enum = enum - list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, - self.xtype, self.solver) - self.options_expls = {} - self.options_cont_expls = {} - for i in range(len(list_expls)): - self.options_expls[str(list_expls[i])] = list_expls[i] - for i in range(len(list_cont_expls)): - self.options_cont_expls[str(list_cont_expls[i])] = list_cont_expls[i] - - def update_xtype(self, xtype): - self.xtype = xtype - list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, - self.xtype, self.solver) - self.options_expls = {} - self.options_cont_expls = {} - for i in range(len(list_expls)): - self.options_expls[str(list_expls[i])] = list_expls[i] - for i in range(len(list_cont_expls)): - self.options_cont_expls[str(list_cont_expls[i])] = list_cont_expls[i] - - def update_solver(self, solver): - self.solver = solver - list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, - self.xtype, self.solver) - self.options_expls = {} - self.options_cont_expls = {} - for i in range(len(list_expls)): - self.options_expls[str(list_expls[i])] = list_expls[i] - for i in range(len(list_cont_expls)): - self.options_cont_expls[str(list_cont_expls[i])] = list_cont_expls[i] - - def update_expl(self, expl): - self.expl = expl - self.component.draw_explanation(self.instance, self.expl) - - def update_cont_expl(self, cont_expl): - self.cont_expl = cont_expl - self.component.draw_contrastive_explanation(self.instance, self.cont_expl) - class View: - + """ Creates the layout of the app""" def __init__(self, model): self.model = model self.ml_menu_models = html.Div([ - html.Br(), + html.P(), html.Label("Choose the Machine Learning algorithm :"), - html.Br(), + html.P(), dcc.Dropdown(self.model.ml_models, id='ml_model_choice', className="dropdown")]) @@ -155,7 +54,7 @@ class View: self.pretrained_model_upload = html.Div([ html.Hr(), html.Label("Choose the pretrained model : "), - html.Br(), + html.P(), dcc.Upload( id='ml_pretrained_model_choice', children=html.Div([ @@ -169,7 +68,7 @@ class View: self.add_model_info_choice = html.Div([ html.Hr(), html.Label("Upload more data model (only for model with categorical variables) ?"), - html.Br(), + html.P(), daq.BooleanSwitch(id='add_info_model_choice', on=False, color="#000000", )]) self.model_info = html.Div(id="choice_info_div", @@ -178,7 +77,7 @@ class View: html.Hr(), html.Label( "Choose the pretrained model dataset (csv) or feature definition file (txt): "), - html.Br(), + html.P(), dcc.Upload( id='model_info_choice', children=html.Div([ @@ -192,7 +91,7 @@ class View: self.instance_upload = html.Div([ html.Hr(), html.Label("Choose the instance to explain : "), - html.Br(), + html.P(), dcc.Upload( id='ml_instance_choice', children=html.Div([ @@ -205,7 +104,7 @@ class View: self.num_explanation = html.Div([ html.Label("Choose the number of explanations : "), - html.Br(), + html.P(), dcc.Input( id="number_explanations", value=1, @@ -216,7 +115,7 @@ class View: self.type_explanation = html.Div([ html.Label("Choose the kind of explanation : "), - html.Br(), + html.P(), dcc.Checklist( id="explanation_type", options=self.model.xtypes, @@ -225,7 +124,7 @@ class View: html.Hr()]) self.solver = html.Div([html.Label("Choose the SAT solver : "), - html.Br(), + html.P(), dcc.Dropdown(self.model.solvers, id='solver_sat')]) @@ -236,12 +135,17 @@ class View: self.model_info, dcc.Store(id='intermediate-value-data'), self.pretrained_model_upload, + html.P(), html.Button('Submit model', id='submit-model', n_clicks=0, className="button"), html.Hr(), self.instance_upload, - html.Button('Submit instance', id='submit-instance', n_clicks=0, className="button"),], className="sidebar"), + html.Hr(), + html.Label("Set Advanced parameters then submit :"), + html.P(), + html.Button('Submit for explanation', id='submit-instance', n_clicks=0, className="button"), ], + className="sidebar"), dcc.Tab(label='Advanced Parameters', children=[ - html.Br(), + html.P(), self.num_explanation, self.type_explanation, self.solver @@ -249,7 +153,7 @@ class View: self.switch = html.Div(id="div_switcher_draw_expl", hidden=True, children=[html.Label("Draw explanations ?"), - html.Br(), + html.P(), daq.BooleanSwitch(id='drawing_expl', on=False, color="#FFFFFF", )]) self.expl_choice = html.Div(id="interaction_graph", hidden=True,