# callbacks.py (13.60 KiB)
# Authored by Caroline DE POURTALES.
import ast
import json
import dash
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
from dash import Input, Output, State, html
from dash import dcc
from utils import parse_contents_graph, parse_contents_instance, parse_contents_data
from pages.application.application import Model, View
from utils import extract_data
from pages.application.RandomForest.utils import xrf
from pages.application.RandomForest.utils.xrf import *
import sys

# Register the project's RandomForest `xrf` module under the top-level name
# 'xrf'. NOTE(review): presumably this lets pickled models that reference the
# module as plain `xrf` be unpickled — confirm against the model files.
# `sys` is imported explicitly instead of relying on the star import above to
# leak it into this namespace.
sys.modules['xrf'] = xrf
def _make_alert(message, color, duration):
    """Build an already-open Bootstrap alert wrapped in an ``html.Div``.

    Args:
        message: Text displayed inside the alert.
        color: Bootstrap contextual colour passed to ``dbc.Alert``
            (e.g. ``'warning'``, ``'danger'``, ``'info'``).
        duration: Value forwarded to ``dbc.Alert(duration=...)`` (milliseconds
            before the alert dismisses itself).

    Returns:
        An ``html.Div`` containing a single open ``dbc.Alert``.
    """
    return html.Div([dbc.Alert(message, is_open=True, color=color, duration=duration), ])


def register_callbacks(app):
    """Register every Dash callback of the application on *app*.

    Builds the static home/course pages, loads the model catalogue from
    ``data_retriever.json`` (resolved relative to the current working
    directory), creates the ``Model``/``View`` pair used by the application
    page and attaches all callbacks below to *app*.

    Args:
        app: The Dash application the callbacks are registered on.
    """
    import logging

    logger = logging.getLogger(__name__)

    page_list = ['home', 'course', 'application']

    # For home directory
    page_home = dbc.Row([])

    # For course directory
    main_course = dcc.Tabs(children=[
        dcc.Tab(label='Data format', children=[]),
        dcc.Tab(label='Course Decision Tree', children=[])])
    page_course = dbc.Row([main_course])

    # The context manager closes the catalogue file as soon as it is parsed.
    with open('data_retriever.json', encoding='utf-8') as models_data:
        data = json.load(models_data)["data"]

    # For the application.
    # NOTE(review): model_application is created once here and mutated by the
    # callbacks below, so its state is shared by every client of this server
    # process — concurrent users will overwrite each other's selections.
    names_models, dict_components, dic_solvers, dic_xtypes = extract_data(data)
    model_application = Model(names_models, dict_components, dic_solvers, dic_xtypes)
    view_application = View(model_application)

    ####### Creates alerts when callback fails #########
    warning_selection_model = _make_alert(
        "You didn't choose a kind of Machine Learning model first.",
        'warning', 10000)
    warning_selection_pretrained_model = _make_alert(
        "You uploaded the data, now upload the pretrained model.",
        'warning', 10000)
    warning_selection_data = _make_alert(
        "You uploaded the model, now upload the data.",
        'warning', 10000)
    warning_selection_instance = _make_alert(
        "You uploaded the model, now upload the instance to explain.",
        'warning', 10000)
    alert_network = _make_alert(
        "There was a problem while computing the graph, read the documentation. "
        "You might have forgotten to upload the data for Random Forest "
        "or you tried to upload an unknown format.",
        'danger', 10000)
    alert_explanation = _make_alert(
        "There was a problem while computing the explanation. "
        "Read the documentation to understand which kinds of format are accepted.",
        'danger', 10000)
    reinit = _make_alert(
        "Reinitialization caused by changing model type or pkl or data or instance.",
        'info', 5000)
    init_network = _make_alert("Initialization.", 'info', 5000)
    ######################################################

    @app.callback(
        Output('page-content', 'children'),
        Input('url', 'pathname'))
    def display_page(pathname):
        """Return the layout for *pathname*; unknown paths render nothing (None)."""
        if pathname == '/':
            return page_home
        if pathname == '/application':
            return view_application.layout
        if pathname == '/course':
            return page_course
        return None

    @app.callback(Output('home-link', 'active'),
                  Output('course-link', 'active'),
                  Output('application-link', 'active'),
                  Input('url', 'pathname'))
    def navbar_state(pathname):
        """Return one ``active`` flag per navbar link (home, course, application).

        ``display_page`` serves the home page at ``'/'``, so that path also
        activates the home link, as does ``'/home'``.
        """
        if pathname == '/':
            pathname = '/home'
        return tuple(pathname == f'/{page}' for page in page_list)

    # region mltype
    @app.callback(Output('solver_sat', 'options'),
                  Output('solver_sat', 'value'),
                  Output('explanation_type', 'options'),
                  Output('explanation_type', 'value'),
                  Input('ml_model_choice', 'value'),
                  prevent_initial_call=True
                  )
    def update_ml_type_options(ml_type):
        """Refresh solver and explanation-type choices for the chosen model type.

        Returns the solver options, the first solver as the default value,
        the explanation-type options, and a one-element list holding the
        first explanation-type key as the default value.
        """
        model_application.update_ml_model(ml_type)
        solvers = model_application.solvers
        xtypes = model_application.xtypes
        return solvers, solvers[0], xtypes, [list(xtypes.keys())[0]]
    # endregion

    # region pretrained model
    @app.callback(
        Output('pretrained_model_filename', 'children'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        State('ml_pretrained_model_choice', 'filename'),
        prevent_initial_call=True)
    def select_model(ml_type, model, filename):
        """Load an uploaded pretrained model and show its filename.

        Changing the model type clears the displayed filename instead.
        """
        ctx = dash.callback_context
        if not ctx.triggered:
            return None
        if ctx.triggered_id == 'ml_model_choice':
            return None
        graph = parse_contents_graph(model, filename)
        model_application.update_pretrained_model(graph)
        return filename
    # endregion

    # region data
    @app.callback(
        Output('add_info_model_choice', 'on'),
        Input('ml_model_choice', 'value'),
        prevent_initial_call=True
    )
    def delete_info(ml_type):
        """Switch the "additional info" toggle off whenever the model type changes."""
        model_application.update_info_needed(False)
        return False

    @app.callback(
        Output('choice_info_div', 'hidden'),
        Input('add_info_model_choice', 'on'),
        prevent_initial_call=True
    )
    def add_model_info(add_info_model_choice):
        """Record whether extra info is needed; hide its upload div when it is not."""
        model_application.update_info_needed(add_info_model_choice)
        return not add_info_model_choice

    @app.callback(Output('info_filename', 'children'),
                  Output('intermediate-value-data', 'data'),
                  Input('ml_model_choice', 'value'),
                  Input('model_info_choice', 'contents'),
                  State('model_info_choice', 'filename'),
                  prevent_initial_call=True
                  )
    def select_data(ml_type, data, filename):
        """Parse uploaded model information and store it.

        Returns ``(filename, parsed_info)``. A model-type change resets both
        outputs to None.
        """
        ctx = dash.callback_context
        if not ctx.triggered:
            return None
        if ctx.triggered_id == 'ml_model_choice':
            return None, None
        model_info = parse_contents_data(data, filename)
        model_application.update_info(model_info)
        return filename, model_info
    # endregion

    # region instance
    @app.callback(Output('instance_filename', 'children'),
                  Input('ml_model_choice', 'value'),
                  Input('ml_pretrained_model_choice', 'contents'),
                  Input('ml_instance_choice', 'contents'),
                  State('ml_instance_choice', 'filename'),
                  Input('number_explanations', 'value'),
                  Input('explanation_type', 'value'),
                  Input('solver_sat', 'value'),
                  prevent_initial_call=True
                  )
    def select_instance(ml_type, model, instance, filename, enum, xtype, solver):
        """Update explanation settings or load a new instance.

        A change of model type or pretrained model clears the displayed
        instance filename. Otherwise the setting that triggered the callback
        is forwarded to the model, and the current filename is returned.
        """
        ctx = dash.callback_context
        if not ctx.triggered:
            return None
        ihm_id = ctx.triggered_id
        if ihm_id in ('ml_model_choice', 'ml_pretrained_model_choice'):
            return None
        if ihm_id == 'number_explanations':
            # Choice of number of explanations
            model_application.update_enum(enum)
        elif ihm_id == 'explanation_type':
            # Choice of AxP or CxP
            model_application.update_xtype(xtype)
        elif ihm_id == 'solver_sat':
            # Choice of solver
            model_application.update_solver(solver)
        else:
            instance = parse_contents_instance(instance, filename)
            model_application.update_instance(instance)
        return filename
    # endregion

    # region draw
    @app.callback(
        Output('graph', 'children'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        Input('submit-model', 'n_clicks'),
        State('model_info_choice', 'filename'),
        Input('expl_choice', 'value'),
        Input('cont_expl_choice', 'value'),
        Input('choice_tree', 'value'),
        prevent_initial_call=True)
    def draw_model(ml_type, model, click, model_info_filename, expl_choice, cont_expl_choice, id_tree):
        """Render the model graph, or a status/warning alert.

        Any exception raised while building the graph is logged and replaced
        by ``alert_network`` so the page stays usable.
        """
        ctx = dash.callback_context
        try:
            if ctx.triggered:
                ihm_id = ctx.triggered_id
                if ihm_id == "ml_model_choice":
                    return reinit
                elif ihm_id == "ml_pretrained_model_choice":
                    return init_network
                elif ihm_id == "submit-model":
                    if model_application.ml_model is None:
                        return warning_selection_model
                    elif not model_application.add_info:
                        model_application.update_pretrained_model_layout()
                        return model_application.component.network
                    elif model_application.model_info is not None:
                        model_application.update_pretrained_model_layout_with_info(model_info_filename)
                        return model_application.component.network
                    else:
                        return warning_selection_data
                elif ihm_id == 'expl_choice':
                    # Choice of AxP to draw
                    model_application.update_expl(expl_choice)
                    return model_application.component.network
                elif ihm_id == 'cont_expl_choice':
                    # Choice of CxP to draw
                    model_application.update_cont_expl(cont_expl_choice)
                    return model_application.component.network
                elif ihm_id == 'choice_tree':
                    # In the case of RandomForest, id of the tree to draw
                    model_application.update_tree_to_plot(id_tree)
                    return model_application.component.network
        except Exception:
            # Narrower than a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; the traceback is logged instead of discarded.
            logger.exception("Failed to compute the model graph")
            return alert_network
    # endregion

    # region explanation
    @app.callback(
        Output('explanation', 'hidden'),
        Input('submit-instance', 'n_clicks'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        Input('ml_instance_choice', 'contents'),
        prevent_initial_call=True
    )
    def show_explanation_window(click, ml_type, model, instance):
        """Show the explanation panel only right after "submit-instance" is clicked."""
        ctx = dash.callback_context
        if not ctx.triggered:
            return None
        return ctx.triggered_id != "submit-instance"

    @app.callback(
        Output('explanation', 'children'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        Input('ml_instance_choice', 'contents'),
        Input('submit-instance', 'n_clicks'),
        prevent_initial_call=True)
    def run_explanation(ml_type, model, instance, click):
        """Return the explanation component, or an alert explaining what is missing.

        Any change of model type, pretrained model or instance returns the
        ``reinit`` notice. Exceptions raised while producing the explanation
        are logged and replaced by ``alert_explanation``.
        """
        ctx = dash.callback_context
        if not ctx.triggered:
            return None
        ihm_id = ctx.triggered_id
        if ihm_id in ("ml_model_choice", "ml_pretrained_model_choice", "ml_instance_choice"):
            return reinit
        if ihm_id == "submit-instance":
            try:
                if model_application.ml_model is None:
                    return warning_selection_model
                elif model_application.pretrained_model is None:
                    return warning_selection_pretrained_model
                elif model_application.instance is not None:
                    return model_application.component.explanation
                else:
                    # The missing piece is the instance, not the data file.
                    return warning_selection_instance
            except Exception:
                logger.exception("Failed to compute the explanation")
                return alert_explanation
        return None
    # endregion

    ########### RandomForest ###########
    @app.callback(
        Output('choosing_tree', 'hidden'),
        Input('ml_model_choice', 'value'),
        prevent_initial_call=True
    )
    def choose_tree_in_forest(ml_type):
        """Show the tree selector only for RandomForest models."""
        return ml_type != "RandomForest"

    ########### DecisionTree ###########
    @app.callback(
        Output('div_switcher_draw_expl', 'hidden'),
        Input('ml_model_choice', 'value'),
        prevent_initial_call=True
    )
    def show_switcher_draw(ml_type):
        """Show the "draw explanation" switch only for DecisionTree models."""
        return ml_type != "DecisionTree"

    @app.callback(
        Output('drawing_expl', 'on'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        Input('model_info_choice', 'contents'),
        Input('ml_instance_choice', 'contents'),
        prevent_initial_call=True
    )
    def turn_switcher_draw_off(ml_type, model, data, instance):
        """Turn the "draw explanation" switch off whenever any upload or the model type changes."""
        return False

    @app.callback(
        Output('interaction_graph', 'hidden'),
        Output('expl_choice', 'options'),
        Output('cont_expl_choice', 'options'),
        Input('drawing_expl', 'on'),
        prevent_initial_call=True
    )
    def switcher_drawing_options(bool_draw):
        """Hide the interaction panel and clear its options while drawing is off."""
        if not bool_draw:
            return True, {}, {}
        return False, model_application.options_expls, model_application.options_cont_expls