From 09e7fa2e650abf166a641a9c12881d19e5475e50 Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <cdepourt@montana.irit.fr> Date: Tue, 14 Jun 2022 10:46:07 +0200 Subject: [PATCH] adding callbacks chnaged with buttons --- app.py | 37 +--- assets/header.css | 15 ++ callbacks.py | 335 +++++++++++++++++-------------- pages/application/application.py | 33 ++- requirements.txt | 5 +- 5 files changed, 217 insertions(+), 208 deletions(-) diff --git a/app.py b/app.py index 41087dc..d04a864 100644 --- a/app.py +++ b/app.py @@ -1,50 +1,17 @@ # Run this app with `python app.py` and # visit http://127.0.0.1:8050/ in your web browser. -import json - import dash import dash_bootstrap_components as dbc from dash import dcc, html - from callbacks import register_callbacks -from pages.application.application import Application, Model, View -from utils import extract_data app = dash.Dash(__name__, external_stylesheets=[dbc.themes.LUX], suppress_callback_exceptions=True, meta_tags=[{'name': 'viewport', 'content': 'width=device-width, initial-scale=1'}]) +app.config.suppress_callback_exceptions = True ################################################################################# ############################# Layouts ########################################### ################################################################################# -models_data = open('data_retriever.json') -data = json.load(models_data)["data"] - -# For home directory -welcome_message = html.Div(html.Iframe( - src=app.get_asset_url("welcome.html"), - style={"height": "1067px", "width": "100%"}, -)) -page_home = dbc.Row([welcome_message]) - -# For course directory -course_data_format = html.Div(html.Iframe( - src=app.get_asset_url("course_data_format.html"), - style={"height": "1067px", "width": "100%"}, -)) -course_decision_tree = html.Iframe( - src="assets/course_decision_tree.html", - style={"height": "1067px", "width": "100%"}, -) -main_course = dcc.Tabs(children=[ - dcc.Tab(label='Data format', children=[course_data_format]), - dcc.Tab(label='Course Decision Tree', children=[course_decision_tree])]) -page_course = dbc.Row([main_course]) - -# For the application -names_models, dict_components, dic_solvers, dic_xtypes = extract_data(data) -model_application = Model(names_models, dict_components, dic_solvers, dic_xtypes) -view_application = View(model_application) -page_application = Application(view_application) server = app.server @@ -67,7 +34,7 @@ app.layout = html.Div([ ################################################################################# ################################# Callback for the app ########################## ################################################################################# -register_callbacks(page_home, page_course, page_application, app) +register_callbacks(app) ################################################################################# ################################# Launching app ################################# diff --git a/assets/header.css b/assets/header.css index 9005062..97a186c 100644 --- a/assets/header.css +++ b/assets/header.css @@ -80,6 +80,21 @@ div.sidebar.col-3 { background-color: rgb(26,26,26); } +.sidebar .button{ + display: block; + margin-left: auto; + margin-right: auto; + width: 50%; + height: 30px; + border-width: 1px; + border-radius: 5px; + text-align: center; + align: center; + color:rgb(255, 255, 255); + font-weight: 400; + background-color: rgb(26,26,26); +} + .sidebar .has-value.Select--single > .Select-control .Select-value .Select-value-label, .has-value.is-pseudo-focused.Select--single > .sidebar .Select-control .Select-value .Select-value-label { color:rgb(255, 255, 255); } diff --git a/callbacks.py b/callbacks.py index 4894d75..47e3897 100644 --- a/callbacks.py +++ b/callbacks.py @@ -1,38 +1,51 @@ import ast - +import json import dash from dash.dependencies import Input, Output, State import dash_bootstrap_components as dbc from dash import Input, Output, State, html +from dash import dcc from utils import parse_contents_graph, parse_contents_instance, parse_contents_data +from pages.application.application import Model, View +from utils import extract_data from pages.application.RandomForest.utils import xrf from pages.application.RandomForest.utils.xrf import * - sys.modules['xrf'] = xrf -def register_callbacks(page_home, page_course, page_application, app): +def register_callbacks(app): page_list = ['home', 'course', 'application'] - ####### Creates alerts when callback fails ######### + # For home directory + page_home = dbc.Row([]) + # For course directory + main_course = dcc.Tabs(children=[ + dcc.Tab(label='Data format', children=[]), + dcc.Tab(label='Course Decision Tree', children=[])]) + page_course = dbc.Row([main_course]) + + models_data = open('data_retriever.json') + data = json.load(models_data)["data"] + # For the application + names_models, dict_components, dic_solvers, dic_xtypes = extract_data(data) + model_application = Model(names_models, dict_components, dic_solvers, dic_xtypes) + view_application = View(model_application) + ####### Creates alerts when callback fails ######### warning_selection_model = html.Div([dbc.Alert("You didn't choose a king of Machine Learning model first.", is_open=True, color='warning', duration=10000, ), ]) - warning_selection_pretrained_model = html.Div([dbc.Alert("You uploaded the data, now upload the pretrained model.", is_open=True, color='warning', duration=10000, ), ]) - warning_selection_data = html.Div([dbc.Alert("You uploaded the model, now upload the data.", is_open=True, color='warning', duration=10000, ), ]) - alert_network = html.Div([dbc.Alert("There was a problem while computing the graph, read the documentation. \ You might have forgotten to upload the data for Random Forest or you tried to upload an unknown format.", is_open=True, @@ -44,7 +57,17 @@ def register_callbacks(page_home, page_course, page_application, app): color='danger', duration=10000, ), ]) - alert_version_model = [] + reinit = html.Div([dbc.Alert( + "Reinitialization caused by changing model type or pkl or data or instance.", + is_open=True, + color='info', + duration=5000, ), ]) + + init_network = html.Div([dbc.Alert( + "Initialization.", + is_open=True, + color='info', + duration=5000, ), ]) ###################################################### @@ -55,7 +78,7 @@ def register_callbacks(page_home, page_course, page_application, app): if pathname == '/': return page_home if pathname == '/application': - return page_application.view.layout + return view_application.layout if pathname == '/course': return page_course @@ -67,6 +90,7 @@ def register_callbacks(page_home, page_course, page_application, app): active_link = ([pathname == f'/{i}' for i in page_list]) return active_link[0], active_link[1], active_link[2] + # region mltype @app.callback(Output('solver_sat', 'options'), Output('solver_sat', 'value'), Output('explanation_type', 'options'), @@ -74,211 +98,217 @@ def register_callbacks(page_home, page_course, page_application, app): Input('ml_model_choice', 'value'), prevent_initial_call=True ) - def update_ml_type_options(value_ml_model): - model_application = page_application.model - model_application.update_ml_model(value_ml_model) + def update_ml_type_options(ml_type): + model_application.update_ml_model(ml_type) return model_application.solvers, model_application.solvers[0], model_application.xtypes, [ list(model_application.xtypes.keys())[0]] - @app.callback(Output('pretrained_model_filename', 'children'), - Input('ml_model_choice', 'value'), - Input('ml_pretrained_model_choice', 'filename'), - prevent_initial_call=True - ) - def update_model_pretrained_name(value_ml_model, pretrained_model_filename): + # endregion + + # region pretrained model + @app.callback( + Output('pretrained_model_filename', 'children'), + Input('ml_model_choice', 'value'), + Input('ml_pretrained_model_choice', 'contents'), + State('ml_pretrained_model_choice', 'filename'), + prevent_initial_call=True) + def select_model(ml_type, model, filename): ctx = dash.callback_context if ctx.triggered: - ihm_id = ctx.triggered[0]['prop_id'].split('.')[0] + ihm_id = ctx.triggered_id if ihm_id == 'ml_model_choice': return None else: - return pretrained_model_filename + graph = parse_contents_graph(model, filename) + model_application.update_pretrained_model(graph) + return filename + + # endregion + + # region data + @app.callback( + Output('add_info_model_choice', 'on'), + Input('ml_model_choice', 'value'), + prevent_initial_call=True + ) + def delete_info(ml_type): + model_application.update_info_needed(False) + return False + + @app.callback( + Output('choice_info_div', 'hidden'), + Input('add_info_model_choice', 'on'), + prevent_initial_call=True + ) + def add_model_info(add_info_model_choice): + model_application.update_info_needed(add_info_model_choice) + if add_info_model_choice: + return False + else: + return True @app.callback(Output('info_filename', 'children'), + Output('intermediate-value-data', 'data'), Input('ml_model_choice', 'value'), - Input('model_info_choice', 'filename'), + Input('model_info_choice', 'contents'), + State('model_info_choice', 'filename'), prevent_initial_call=True ) - def update_model_info_filename(value_ml_model, model_info_filename): + def select_data(ml_type, data, filename): ctx = dash.callback_context if ctx.triggered: - ihm_id = ctx.triggered[0]['prop_id'].split('.')[0] + ihm_id = ctx.triggered_id if ihm_id == 'ml_model_choice': - return None + return None, None else: - return model_info_filename + model_info = parse_contents_data(data, filename) + model_application.update_info(model_info) + return filename, model_info + + # endregion + # region instance @app.callback(Output('instance_filename', 'children'), Input('ml_model_choice', 'value'), Input('ml_pretrained_model_choice', 'contents'), - Input('ml_instance_choice', 'filename'), + Input('ml_instance_choice', 'contents'), + State('ml_instance_choice', 'filename'), + Input('number_explanations', 'value'), + Input('explanation_type', 'value'), + Input('solver_sat', 'value'), prevent_initial_call=True ) - def update_instance_filename(value_ml_model, model_info_filename, instance_filename): + def select_instance(ml_type, model, instance, filename, enum, xtype, solver): ctx = dash.callback_context if ctx.triggered: - ihm_id = ctx.triggered[0]['prop_id'].split('.')[0] - if ihm_id == 'ml_model_choice': - return None - elif ihm_id == 'ml_pretrained_model_choice': + ihm_id = ctx.triggered_id + if ihm_id == 'ml_model_choice' or ihm_id == 'ml_pretrained_model_choice': return None else: - return instance_filename + # Choice of number of expls + if ihm_id == 'number_explanations': + model_application.update_enum(enum) + # Choice of AxP or CxP + elif ihm_id == 'explanation_type': + model_application.update_xtype(xtype) + # Choice of solver + elif ihm_id == 'solver_sat': + model_application.update_solver(solver) + else: + instance = parse_contents_instance(instance, filename) + model_application.update_instance(instance) + return filename - @app.callback( - Output('choice_info_div', 'hidden'), - Input('add_info_model_choice', 'on'), - prevent_initial_call=True - ) - def add_model_info(add_info_model_choice): - model_application = page_application.model - model_application.update_info_needed(add_info_model_choice) - if add_info_model_choice: - return False - else: - return True + # endregion + # region draw @app.callback( Output('graph', 'children'), - Output('explanation', 'children'), Input('ml_model_choice', 'value'), Input('ml_pretrained_model_choice', 'contents'), - State('ml_pretrained_model_choice', 'filename'), - Input('model_info_choice', 'contents'), + Input('submit-model', 'n_clicks'), State('model_info_choice', 'filename'), - Input('ml_instance_choice', 'contents'), - State('ml_instance_choice', 'filename'), - Input('number_explanations', 'value'), - Input('explanation_type', 'value'), - Input('solver_sat', 'value'), Input('expl_choice', 'value'), Input('cont_expl_choice', 'value'), Input('choice_tree', 'value'), - prevent_initial_call=True - ) - def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, model_info, - model_info_filename, instance_contents, instance_filename, enum, xtype, solver, expl_choice, - cont_expl_choice, id_tree): + prevent_initial_call=True) + def draw_model(ml_type, model, click, model_info_filename, expl_choice, cont_expl_choice, id_tree): ctx = dash.callback_context - if ctx.triggered: - ihm_id = ctx.triggered[0]['prop_id'].split('.')[0] - model_application = page_application.model - - # Choice of model - if ihm_id == 'ml_model_choice': - model_application.update_ml_model(value_ml_model) - return None, None - - # Choice of information for the model - data, feature mapping ... - elif ihm_id == 'model_info_choice': - try: - model_info = parse_contents_data(model_info, model_info_filename) - model_application.update_info(model_info) + try: + if ctx.triggered: + ihm_id = ctx.triggered_id + if ihm_id == "ml_model_choice" : + return reinit + elif ihm_id == "ml_pretrained_model_choice": + return init_network + elif ihm_id == "submit-model": if model_application.ml_model is None: - return warning_selection_model, None - elif model_application.pretrained_model is not None: - model_application.update_pretrained_model_layout_with_info(model_info_filename) - return model_application.component.network, None - else: - return warning_selection_model, None - except: - return warning_selection_pretrained_model, None - - # Choice of pkl pretrained model - elif ihm_id == 'ml_pretrained_model_choice': - try: - graph = parse_contents_graph(pretrained_model_contents, pretrained_model_filename) - model_application.update_pretrained_model(graph) - if not model_application.add_info: + return warning_selection_model + elif not model_application.add_info: model_application.update_pretrained_model_layout() - return model_application.component.network, None + return model_application.component.network elif model_application.model_info is not None: model_application.update_pretrained_model_layout_with_info(model_info_filename) - return model_application.component.network, None + return model_application.component.network else: - return warning_selection_data, None - except: - return alert_network, None - - # Choice of instance to explain - elif ihm_id == 'ml_instance_choice': - try: - instance = parse_contents_instance(instance_contents, instance_filename) - model_application.update_instance(instance) - return model_application.component.network, model_application.component.explanation - except: - return model_application.component.network, alert_explanation - - # Choice of number of expls - elif ihm_id == 'number_explanations': - try: - model_application.update_enum(enum) - return model_application.component.network, model_application.component.explanation - except: - return model_application.component.network, alert_explanation - - # Choice of AxP or CxP - elif ihm_id == 'explanation_type': - try: - model_application.update_xtype(xtype) - return model_application.component.network, model_application.component.explanation - except: - return model_application.component.network, alert_explanation - - # Choice of solver - elif ihm_id == 'solver_sat': - try: - model_application.update_solver(solver) - return model_application.component.network, model_application.component.explanation - except: - return model_application.component.network, alert_explanation + return warning_selection_data - # Choice of AxP to draw - elif ihm_id == 'expl_choice': - try: + elif ihm_id == 'expl_choice': model_application.update_expl(expl_choice) - return model_application.component.network, model_application.component.explanation - except: - return alert_network, model_application.component.explanation + return model_application.component.network - # Choice of CxP to draw - elif ihm_id == 'cont_expl_choice': - try: + # Choice of CxP to draw + elif ihm_id == 'cont_expl_choice': model_application.update_cont_expl(cont_expl_choice) - return model_application.component.network, model_application.component.explanation - except: - return alert_network, model_application.component.explanation + return model_application.component.network - # In the case of RandomForest, id of tree to choose to draw tree - elif ihm_id == 'choice_tree': - try: + # In the case of RandomForest, id of tree to choose to draw tree + elif ihm_id == 'choice_tree': model_application.update_tree_to_plot(id_tree) - return model_application.component.network, model_application.component.explanation - except: - return alert_network, model_application.component.explanation + return model_application.component.network + except: + return alert_network + + # endregion + + # region explanation @app.callback( Output('explanation', 'hidden'), - Input('explanation', 'children'), - Input('explanation_type', 'value'), + Input('submit-instance', 'n_clicks'), + Input('ml_model_choice', 'value'), + Input('ml_pretrained_model_choice', 'contents'), + Input('ml_instance_choice', 'contents'), prevent_initial_call=True ) - def show_explanation_window(explanation, explanation_type): - if explanation is None or len(explanation_type) == 0: - return True - else: - return False + def show_explanation_window(click, ml_type, model, instance): + ctx = dash.callback_context + if ctx.triggered: + ihm_id = ctx.triggered_id + if ihm_id == "ml_model_choice" or ihm_id == "ml_pretrained_model_choice" or ihm_id == "ml_instance_choice": + return True + elif ihm_id == "submit-instance": + return False + else : + return True + + @app.callback( + Output('explanation', 'children'), + Input('ml_model_choice', 'value'), + Input('ml_pretrained_model_choice', 'contents'), + Input('ml_instance_choice', 'contents'), + Input('submit-instance', 'n_clicks'), + prevent_initial_call=True) + def run_explanation(ml_type, model, instance, click): + ctx = dash.callback_context + if ctx.triggered: + ihm_id = ctx.triggered_id + if ihm_id == "ml_model_choice" or ihm_id == "ml_pretrained_model_choice" or ihm_id == "ml_instance_choice": + return reinit + elif ihm_id == "submit-instance": + try: + if model_application.ml_model is None: + return warning_selection_model + elif model_application.pretrained_model is None: + return warning_selection_pretrained_model + elif model_application.instance is not None: + return model_application.component.explanation + else: + return warning_selection_data + except: + return alert_explanation + + # endregion ########### RandomForest ########### @app.callback( Output('choosing_tree', 'hidden'), Input('ml_model_choice', 'value'), - Input('graph', 'children'), prevent_initial_call=True ) - def choose_tree_in_forest(ml_type, graph): - if ml_type == "RandomForest" and graph is not None: + def choose_tree_in_forest(ml_type): + if ml_type == "RandomForest": return False else: return True @@ -318,5 +348,4 @@ def register_callbacks(page_home, page_course, page_application, app): if not bool_draw: return True, {}, {} else: - model_application = page_application.model return False, model_application.options_expls, model_application.options_cont_expls diff --git a/pages/application/application.py b/pages/application/application.py index 4ce38fb..d952c3a 100644 --- a/pages/application/application.py +++ b/pages/application/application.py @@ -5,14 +5,6 @@ import dash_daq as daq from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent from pages.application.RandomForest.RandomForestComponent import RandomForestComponent - - -class Application: - def __init__(self, view): - self.view = view - self.model = view.model - - class Model: def __init__(self, names_models, dict_components, dic_solvers, dic_xtypes): @@ -55,7 +47,7 @@ class Model: self.xtypes = self.dic_xtypes[self.ml_model] self.xtype = [list(self.xtypes.keys())[0]] - #init all params + # init all params self.component = None self.pretrained_model = None self.model_info = None @@ -77,6 +69,8 @@ class Model: self.cont_expl = None def update_info_needed(self, add_info): + if not add_info: + self.model_info = None self.add_info = add_info def update_info(self, model_info): @@ -95,7 +89,7 @@ class Model: def update_instance(self, instance): self.instance = instance list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, - self.xtype, self.solver) + self.xtype, self.solver) self.options_expls = {} self.options_cont_expls = {} for i in range(len(list_expls)): @@ -106,7 +100,7 @@ class Model: def update_enum(self, enum): self.enum = enum list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, - self.xtype, self.solver) + self.xtype, self.solver) self.options_expls = {} self.options_cont_expls = {} for i in range(len(list_expls)): @@ -114,11 +108,10 @@ class Model: for i in range(len(list_cont_expls)): self.options_cont_expls[str(list_cont_expls[i])] = list_cont_expls[i] - def update_xtype(self, xtype): self.xtype = xtype list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, - self.xtype, self.solver) + self.xtype, self.solver) self.options_expls = {} self.options_cont_expls = {} for i in range(len(list_expls)): @@ -129,7 +122,7 @@ class Model: def update_solver(self, solver): self.solver = solver list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, - self.xtype, self.solver) + self.xtype, self.solver) self.options_expls = {} self.options_cont_expls = {} for i in range(len(list_expls)): @@ -241,8 +234,12 @@ class View: self.ml_menu_models, self.add_model_info_choice, self.model_info, + dcc.Store(id='intermediate-value-data'), self.pretrained_model_upload, - self.instance_upload], className="sidebar"), + html.Button('Submit model', id='submit-model', n_clicks=0, className="button"), + html.Hr(), + self.instance_upload, + html.Button('Submit instance', id='submit-instance', n_clicks=0, className="button"),], className="sidebar"), dcc.Tab(label='Advanced Parameters', children=[ html.Br(), self.num_explanation, @@ -251,9 +248,9 @@ class View: ], className="sidebar")]) self.switch = html.Div(id="div_switcher_draw_expl", hidden=True, - children=[html.Label("Draw explanations ?"), - html.Br(), - daq.BooleanSwitch(id='drawing_expl', on=False, color="#FFFFFF", )]) + children=[html.Label("Draw explanations ?"), + html.Br(), + daq.BooleanSwitch(id='drawing_expl', on=False, color="#FFFFFF", )]) self.expl_choice = html.Div(id="interaction_graph", hidden=True, children=[html.H5("Navigate through the explanations and plot them on the tree : "), diff --git a/requirements.txt b/requirements.txt index 7ea8eec..165350a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ shap==0.40.0 anchor-exp==0.0.2.0 pysmt==0.9.0 anytree==2.8.0 -dash==2.1.0 +dash==2.4.0 dash_bootstrap_components==1.0.3 dash_daq==0.5.0 dash_interactive_graphviz==0.3.0 @@ -21,4 +21,5 @@ scikit_learn==1.0.2 six==1.16.0 gunicorn==20.1.0 Werkzeug==2.0.1 -pydot==1.4.2 \ No newline at end of file +pydot==1.4.2 +Flask-Caching==1.8 \ No newline at end of file -- GitLab