Skip to content
Snippets Groups Projects
Commit 09e7fa2e authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

adding callbacks chnaged with buttons

parent 934d20c0
Branches
No related tags found
No related merge requests found
# Run this app with `python app.py` and
# visit http://127.0.0.1:8050/ in your web browser.
import json
import dash
import dash_bootstrap_components as dbc
from dash import dcc, html
from callbacks import register_callbacks
from pages.application.application import Application, Model, View
from utils import extract_data
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.LUX], suppress_callback_exceptions=True,
meta_tags=[{'name': 'viewport', 'content': 'width=device-width, initial-scale=1'}])
app.config.suppress_callback_exceptions = True
#################################################################################
############################# Layouts ###########################################
#################################################################################
models_data = open('data_retriever.json')
data = json.load(models_data)["data"]
# For home directory
welcome_message = html.Div(html.Iframe(
src=app.get_asset_url("welcome.html"),
style={"height": "1067px", "width": "100%"},
))
page_home = dbc.Row([welcome_message])
# For course directory
course_data_format = html.Div(html.Iframe(
src=app.get_asset_url("course_data_format.html"),
style={"height": "1067px", "width": "100%"},
))
course_decision_tree = html.Iframe(
src="assets/course_decision_tree.html",
style={"height": "1067px", "width": "100%"},
)
main_course = dcc.Tabs(children=[
dcc.Tab(label='Data format', children=[course_data_format]),
dcc.Tab(label='Course Decision Tree', children=[course_decision_tree])])
page_course = dbc.Row([main_course])
# For the application
names_models, dict_components, dic_solvers, dic_xtypes = extract_data(data)
model_application = Model(names_models, dict_components, dic_solvers, dic_xtypes)
view_application = View(model_application)
page_application = Application(view_application)
server = app.server
......@@ -67,7 +34,7 @@ app.layout = html.Div([
#################################################################################
################################# Callback for the app ##########################
#################################################################################
register_callbacks(page_home, page_course, page_application, app)
register_callbacks(app)
#################################################################################
################################# Launching app #################################
......
......@@ -80,6 +80,21 @@ div.sidebar.col-3 {
background-color: rgb(26,26,26);
}
.sidebar .button{
display: block;
margin-left: auto;
margin-right: auto;
width: 50%;
height: 30px;
border-width: 1px;
border-radius: 5px;
text-align: center;
align: center;
color:rgb(255, 255, 255);
font-weight: 400;
background-color: rgb(26,26,26);
}
.sidebar .has-value.Select--single > .Select-control .Select-value .Select-value-label, .has-value.is-pseudo-focused.Select--single > .sidebar .Select-control .Select-value .Select-value-label {
color:rgb(255, 255, 255);
}
......
import ast
import json
import dash
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
from dash import Input, Output, State, html
from dash import dcc
from utils import parse_contents_graph, parse_contents_instance, parse_contents_data
from pages.application.application import Model, View
from utils import extract_data
from pages.application.RandomForest.utils import xrf
from pages.application.RandomForest.utils.xrf import *
sys.modules['xrf'] = xrf
def register_callbacks(page_home, page_course, page_application, app):
def register_callbacks(app):
page_list = ['home', 'course', 'application']
####### Creates alerts when callback fails #########
# For home directory
page_home = dbc.Row([])
# For course directory
main_course = dcc.Tabs(children=[
dcc.Tab(label='Data format', children=[]),
dcc.Tab(label='Course Decision Tree', children=[])])
page_course = dbc.Row([main_course])
models_data = open('data_retriever.json')
data = json.load(models_data)["data"]
# For the application
names_models, dict_components, dic_solvers, dic_xtypes = extract_data(data)
model_application = Model(names_models, dict_components, dic_solvers, dic_xtypes)
view_application = View(model_application)
####### Creates alerts when callback fails #########
warning_selection_model = html.Div([dbc.Alert("You didn't choose a king of Machine Learning model first.",
is_open=True,
color='warning',
duration=10000, ), ])
warning_selection_pretrained_model = html.Div([dbc.Alert("You uploaded the data, now upload the pretrained model.",
is_open=True,
color='warning',
duration=10000, ), ])
warning_selection_data = html.Div([dbc.Alert("You uploaded the model, now upload the data.",
is_open=True,
color='warning',
duration=10000, ), ])
alert_network = html.Div([dbc.Alert("There was a problem while computing the graph, read the documentation. \
You might have forgotten to upload the data for Random Forest or you tried to upload an unknown format.",
is_open=True,
......@@ -44,7 +57,17 @@ def register_callbacks(page_home, page_course, page_application, app):
color='danger',
duration=10000, ), ])
alert_version_model = []
reinit = html.Div([dbc.Alert(
"Reinitialization caused by changing model type or pkl or data or instance.",
is_open=True,
color='info',
duration=5000, ), ])
init_network = html.Div([dbc.Alert(
"Initialization.",
is_open=True,
color='info',
duration=5000, ), ])
######################################################
......@@ -55,7 +78,7 @@ def register_callbacks(page_home, page_course, page_application, app):
if pathname == '/':
return page_home
if pathname == '/application':
return page_application.view.layout
return view_application.layout
if pathname == '/course':
return page_course
......@@ -67,6 +90,7 @@ def register_callbacks(page_home, page_course, page_application, app):
active_link = ([pathname == f'/{i}' for i in page_list])
return active_link[0], active_link[1], active_link[2]
# region mltype
@app.callback(Output('solver_sat', 'options'),
Output('solver_sat', 'value'),
Output('explanation_type', 'options'),
......@@ -74,211 +98,217 @@ def register_callbacks(page_home, page_course, page_application, app):
Input('ml_model_choice', 'value'),
prevent_initial_call=True
)
def update_ml_type_options(value_ml_model):
model_application = page_application.model
model_application.update_ml_model(value_ml_model)
def update_ml_type_options(ml_type):
model_application.update_ml_model(ml_type)
return model_application.solvers, model_application.solvers[0], model_application.xtypes, [
list(model_application.xtypes.keys())[0]]
@app.callback(Output('pretrained_model_filename', 'children'),
Input('ml_model_choice', 'value'),
Input('ml_pretrained_model_choice', 'filename'),
prevent_initial_call=True
)
def update_model_pretrained_name(value_ml_model, pretrained_model_filename):
# endregion
# region pretrained model
@app.callback(
Output('pretrained_model_filename', 'children'),
Input('ml_model_choice', 'value'),
Input('ml_pretrained_model_choice', 'contents'),
State('ml_pretrained_model_choice', 'filename'),
prevent_initial_call=True)
def select_model(ml_type, model, filename):
ctx = dash.callback_context
if ctx.triggered:
ihm_id = ctx.triggered[0]['prop_id'].split('.')[0]
ihm_id = ctx.triggered_id
if ihm_id == 'ml_model_choice':
return None
else:
return pretrained_model_filename
graph = parse_contents_graph(model, filename)
model_application.update_pretrained_model(graph)
return filename
# endregion
# region data
@app.callback(
Output('add_info_model_choice', 'on'),
Input('ml_model_choice', 'value'),
prevent_initial_call=True
)
def delete_info(ml_type):
model_application.update_info_needed(False)
return False
@app.callback(
Output('choice_info_div', 'hidden'),
Input('add_info_model_choice', 'on'),
prevent_initial_call=True
)
def add_model_info(add_info_model_choice):
model_application.update_info_needed(add_info_model_choice)
if add_info_model_choice:
return False
else:
return True
@app.callback(Output('info_filename', 'children'),
Output('intermediate-value-data', 'data'),
Input('ml_model_choice', 'value'),
Input('model_info_choice', 'filename'),
Input('model_info_choice', 'contents'),
State('model_info_choice', 'filename'),
prevent_initial_call=True
)
def update_model_info_filename(value_ml_model, model_info_filename):
def select_data(ml_type, data, filename):
ctx = dash.callback_context
if ctx.triggered:
ihm_id = ctx.triggered[0]['prop_id'].split('.')[0]
ihm_id = ctx.triggered_id
if ihm_id == 'ml_model_choice':
return None
return None, None
else:
return model_info_filename
model_info = parse_contents_data(data, filename)
model_application.update_info(model_info)
return filename, model_info
# endregion
# region instance
@app.callback(Output('instance_filename', 'children'),
Input('ml_model_choice', 'value'),
Input('ml_pretrained_model_choice', 'contents'),
Input('ml_instance_choice', 'filename'),
Input('ml_instance_choice', 'contents'),
State('ml_instance_choice', 'filename'),
Input('number_explanations', 'value'),
Input('explanation_type', 'value'),
Input('solver_sat', 'value'),
prevent_initial_call=True
)
def update_instance_filename(value_ml_model, model_info_filename, instance_filename):
def select_instance(ml_type, model, instance, filename, enum, xtype, solver):
ctx = dash.callback_context
if ctx.triggered:
ihm_id = ctx.triggered[0]['prop_id'].split('.')[0]
if ihm_id == 'ml_model_choice':
return None
elif ihm_id == 'ml_pretrained_model_choice':
ihm_id = ctx.triggered_id
if ihm_id == 'ml_model_choice' or ihm_id == 'ml_pretrained_model_choice':
return None
else:
return instance_filename
# Choice of number of expls
if ihm_id == 'number_explanations':
model_application.update_enum(enum)
# Choice of AxP or CxP
elif ihm_id == 'explanation_type':
model_application.update_xtype(xtype)
# Choice of solver
elif ihm_id == 'solver_sat':
model_application.update_solver(solver)
else:
instance = parse_contents_instance(instance, filename)
model_application.update_instance(instance)
return filename
@app.callback(
Output('choice_info_div', 'hidden'),
Input('add_info_model_choice', 'on'),
prevent_initial_call=True
)
def add_model_info(add_info_model_choice):
model_application = page_application.model
model_application.update_info_needed(add_info_model_choice)
if add_info_model_choice:
return False
else:
return True
# endregion
# region draw
@app.callback(
Output('graph', 'children'),
Output('explanation', 'children'),
Input('ml_model_choice', 'value'),
Input('ml_pretrained_model_choice', 'contents'),
State('ml_pretrained_model_choice', 'filename'),
Input('model_info_choice', 'contents'),
Input('submit-model', 'n_clicks'),
State('model_info_choice', 'filename'),
Input('ml_instance_choice', 'contents'),
State('ml_instance_choice', 'filename'),
Input('number_explanations', 'value'),
Input('explanation_type', 'value'),
Input('solver_sat', 'value'),
Input('expl_choice', 'value'),
Input('cont_expl_choice', 'value'),
Input('choice_tree', 'value'),
prevent_initial_call=True
)
def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, model_info,
model_info_filename, instance_contents, instance_filename, enum, xtype, solver, expl_choice,
cont_expl_choice, id_tree):
prevent_initial_call=True)
def draw_model(ml_type, model, click, model_info_filename, expl_choice, cont_expl_choice, id_tree):
ctx = dash.callback_context
if ctx.triggered:
ihm_id = ctx.triggered[0]['prop_id'].split('.')[0]
model_application = page_application.model
# Choice of model
if ihm_id == 'ml_model_choice':
model_application.update_ml_model(value_ml_model)
return None, None
# Choice of information for the model - data, feature mapping ...
elif ihm_id == 'model_info_choice':
try:
model_info = parse_contents_data(model_info, model_info_filename)
model_application.update_info(model_info)
try:
if ctx.triggered:
ihm_id = ctx.triggered_id
if ihm_id == "ml_model_choice" :
return reinit
elif ihm_id == "ml_pretrained_model_choice":
return init_network
elif ihm_id == "submit-model":
if model_application.ml_model is None:
return warning_selection_model, None
elif model_application.pretrained_model is not None:
model_application.update_pretrained_model_layout_with_info(model_info_filename)
return model_application.component.network, None
else:
return warning_selection_model, None
except:
return warning_selection_pretrained_model, None
# Choice of pkl pretrained model
elif ihm_id == 'ml_pretrained_model_choice':
try:
graph = parse_contents_graph(pretrained_model_contents, pretrained_model_filename)
model_application.update_pretrained_model(graph)
if not model_application.add_info:
return warning_selection_model
elif not model_application.add_info:
model_application.update_pretrained_model_layout()
return model_application.component.network, None
return model_application.component.network
elif model_application.model_info is not None:
model_application.update_pretrained_model_layout_with_info(model_info_filename)
return model_application.component.network, None
return model_application.component.network
else:
return warning_selection_data, None
except:
return alert_network, None
# Choice of instance to explain
elif ihm_id == 'ml_instance_choice':
try:
instance = parse_contents_instance(instance_contents, instance_filename)
model_application.update_instance(instance)
return model_application.component.network, model_application.component.explanation
except:
return model_application.component.network, alert_explanation
# Choice of number of expls
elif ihm_id == 'number_explanations':
try:
model_application.update_enum(enum)
return model_application.component.network, model_application.component.explanation
except:
return model_application.component.network, alert_explanation
# Choice of AxP or CxP
elif ihm_id == 'explanation_type':
try:
model_application.update_xtype(xtype)
return model_application.component.network, model_application.component.explanation
except:
return model_application.component.network, alert_explanation
# Choice of solver
elif ihm_id == 'solver_sat':
try:
model_application.update_solver(solver)
return model_application.component.network, model_application.component.explanation
except:
return model_application.component.network, alert_explanation
return warning_selection_data
# Choice of AxP to draw
elif ihm_id == 'expl_choice':
try:
elif ihm_id == 'expl_choice':
model_application.update_expl(expl_choice)
return model_application.component.network, model_application.component.explanation
except:
return alert_network, model_application.component.explanation
return model_application.component.network
# Choice of CxP to draw
elif ihm_id == 'cont_expl_choice':
try:
# Choice of CxP to draw
elif ihm_id == 'cont_expl_choice':
model_application.update_cont_expl(cont_expl_choice)
return model_application.component.network, model_application.component.explanation
except:
return alert_network, model_application.component.explanation
return model_application.component.network
# In the case of RandomForest, id of tree to choose to draw tree
elif ihm_id == 'choice_tree':
try:
# In the case of RandomForest, id of tree to choose to draw tree
elif ihm_id == 'choice_tree':
model_application.update_tree_to_plot(id_tree)
return model_application.component.network, model_application.component.explanation
except:
return alert_network, model_application.component.explanation
return model_application.component.network
except:
return alert_network
# endregion
# region explanation
@app.callback(
Output('explanation', 'hidden'),
Input('explanation', 'children'),
Input('explanation_type', 'value'),
Input('submit-instance', 'n_clicks'),
Input('ml_model_choice', 'value'),
Input('ml_pretrained_model_choice', 'contents'),
Input('ml_instance_choice', 'contents'),
prevent_initial_call=True
)
def show_explanation_window(explanation, explanation_type):
if explanation is None or len(explanation_type) == 0:
return True
else:
return False
def show_explanation_window(click, ml_type, model, instance):
ctx = dash.callback_context
if ctx.triggered:
ihm_id = ctx.triggered_id
if ihm_id == "ml_model_choice" or ihm_id == "ml_pretrained_model_choice" or ihm_id == "ml_instance_choice":
return True
elif ihm_id == "submit-instance":
return False
else :
return True
@app.callback(
Output('explanation', 'children'),
Input('ml_model_choice', 'value'),
Input('ml_pretrained_model_choice', 'contents'),
Input('ml_instance_choice', 'contents'),
Input('submit-instance', 'n_clicks'),
prevent_initial_call=True)
def run_explanation(ml_type, model, instance, click):
ctx = dash.callback_context
if ctx.triggered:
ihm_id = ctx.triggered_id
if ihm_id == "ml_model_choice" or ihm_id == "ml_pretrained_model_choice" or ihm_id == "ml_instance_choice":
return reinit
elif ihm_id == "submit-instance":
try:
if model_application.ml_model is None:
return warning_selection_model
elif model_application.pretrained_model is None:
return warning_selection_pretrained_model
elif model_application.instance is not None:
return model_application.component.explanation
else:
return warning_selection_data
except:
return alert_explanation
# endregion
########### RandomForest ###########
@app.callback(
Output('choosing_tree', 'hidden'),
Input('ml_model_choice', 'value'),
Input('graph', 'children'),
prevent_initial_call=True
)
def choose_tree_in_forest(ml_type, graph):
if ml_type == "RandomForest" and graph is not None:
def choose_tree_in_forest(ml_type):
if ml_type == "RandomForest":
return False
else:
return True
......@@ -318,5 +348,4 @@ def register_callbacks(page_home, page_course, page_application, app):
if not bool_draw:
return True, {}, {}
else:
model_application = page_application.model
return False, model_application.options_expls, model_application.options_cont_expls
......@@ -5,14 +5,6 @@ import dash_daq as daq
from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent
from pages.application.RandomForest.RandomForestComponent import RandomForestComponent
class Application:
def __init__(self, view):
self.view = view
self.model = view.model
class Model:
def __init__(self, names_models, dict_components, dic_solvers, dic_xtypes):
......@@ -55,7 +47,7 @@ class Model:
self.xtypes = self.dic_xtypes[self.ml_model]
self.xtype = [list(self.xtypes.keys())[0]]
#init all params
# init all params
self.component = None
self.pretrained_model = None
self.model_info = None
......@@ -77,6 +69,8 @@ class Model:
self.cont_expl = None
def update_info_needed(self, add_info):
if not add_info:
self.model_info = None
self.add_info = add_info
def update_info(self, model_info):
......@@ -95,7 +89,7 @@ class Model:
def update_instance(self, instance):
self.instance = instance
list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum,
self.xtype, self.solver)
self.xtype, self.solver)
self.options_expls = {}
self.options_cont_expls = {}
for i in range(len(list_expls)):
......@@ -106,7 +100,7 @@ class Model:
def update_enum(self, enum):
self.enum = enum
list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum,
self.xtype, self.solver)
self.xtype, self.solver)
self.options_expls = {}
self.options_cont_expls = {}
for i in range(len(list_expls)):
......@@ -114,11 +108,10 @@ class Model:
for i in range(len(list_cont_expls)):
self.options_cont_expls[str(list_cont_expls[i])] = list_cont_expls[i]
def update_xtype(self, xtype):
self.xtype = xtype
list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum,
self.xtype, self.solver)
self.xtype, self.solver)
self.options_expls = {}
self.options_cont_expls = {}
for i in range(len(list_expls)):
......@@ -129,7 +122,7 @@ class Model:
def update_solver(self, solver):
self.solver = solver
list_expls, list_cont_expls = self.component.update_with_explicability(self.instance, self.enum,
self.xtype, self.solver)
self.xtype, self.solver)
self.options_expls = {}
self.options_cont_expls = {}
for i in range(len(list_expls)):
......@@ -241,8 +234,12 @@ class View:
self.ml_menu_models,
self.add_model_info_choice,
self.model_info,
dcc.Store(id='intermediate-value-data'),
self.pretrained_model_upload,
self.instance_upload], className="sidebar"),
html.Button('Submit model', id='submit-model', n_clicks=0, className="button"),
html.Hr(),
self.instance_upload,
html.Button('Submit instance', id='submit-instance', n_clicks=0, className="button"),], className="sidebar"),
dcc.Tab(label='Advanced Parameters', children=[
html.Br(),
self.num_explanation,
......@@ -251,9 +248,9 @@ class View:
], className="sidebar")])
self.switch = html.Div(id="div_switcher_draw_expl", hidden=True,
children=[html.Label("Draw explanations ?"),
html.Br(),
daq.BooleanSwitch(id='drawing_expl', on=False, color="#FFFFFF", )])
children=[html.Label("Draw explanations ?"),
html.Br(),
daq.BooleanSwitch(id='drawing_expl', on=False, color="#FFFFFF", )])
self.expl_choice = html.Div(id="interaction_graph", hidden=True,
children=[html.H5("Navigate through the explanations and plot them on the tree : "),
......
......@@ -7,7 +7,7 @@ shap==0.40.0
anchor-exp==0.0.2.0
pysmt==0.9.0
anytree==2.8.0
dash==2.1.0
dash==2.4.0
dash_bootstrap_components==1.0.3
dash_daq==0.5.0
dash_interactive_graphviz==0.3.0
......@@ -21,4 +21,5 @@ scikit_learn==1.0.2
six==1.16.0
gunicorn==20.1.0
Werkzeug==2.0.1
pydot==1.4.2
\ No newline at end of file
pydot==1.4.2
Flask-Caching==1.8
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment