# Authored by Caroline DE POURTALES
# callbacks.py (13.01 KiB)
import ast
import json
import dash
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
from dash import Input, Output, State, html
from dash import dcc
from utils import parse_contents_graph, parse_contents_instance, parse_contents_data
from pages.application.application import Model, View
from utils import extract_data
from pages.application.RandomForest.utils import xrf
from pages.application.RandomForest.utils.xrf import *
sys.modules['xrf'] = xrf
from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent
from pages.application.RandomForest.RandomForestComponent import RandomForestComponent
"""
The callbacks are called whenever there is an interaction with the interface
"""
def register_callbacks(app):
    """Register every Dash callback of the application on ``app``.

    The callbacks are called whenever there is an interaction with the interface.
    The static pieces of the layout (home page, course page, alert banners) are
    built once here and reused by the nested callbacks through closures.

    Args:
        app: the Dash application on which the callbacks are registered.
    """
    import logging

    logger = logging.getLogger(__name__)

    page_list = ['home', 'course', 'application']

    # For home directory
    page_home = dbc.Row([])
    # For course directory
    main_course = dcc.Tabs(children=[
        dcc.Tab(label='Data format', children=[]),
        dcc.Tab(label='Course Decision Tree', children=[])])
    page_course = dbc.Row([main_course])

    # Description of the available models; the context manager closes the file once parsed.
    with open('data_retriever.json', encoding='utf-8') as models_data:
        data = json.load(models_data)["data"]
    # For the application
    names_models, dict_components, dic_solvers, dic_xtypes = extract_data(data)

    # region alerts
    warning_selection_model = html.Div([dbc.Alert("You didn't choose a kind of Machine Learning model first.",
                                                  is_open=True,
                                                  color='warning',
                                                  duration=10000, ), ])
    warning_selection_pretrained_model = html.Div([dbc.Alert("You uploaded the data, now upload the pretrained model.",
                                                             is_open=True,
                                                             color='warning',
                                                             duration=10000, ), ])
    warning_selection_data = html.Div([dbc.Alert("You uploaded the model, now upload the data.",
                                                 is_open=True,
                                                 color='warning',
                                                 duration=10000, ), ])
    # Implicit concatenation instead of a backslash continuation, so that the
    # source indentation never leaks into the message.
    alert_network = html.Div([dbc.Alert("There was a problem while computing the graph, read the documentation. "
                                        "You might have forgotten to upload the data for Random Forest "
                                        "or you tried to upload an unknown format.",
                                        is_open=True,
                                        color='danger',
                                        duration=10000, ), ])
    alert_explanation = html.Div([dbc.Alert(
        "There was a problem while computing the explanation. "
        "Read the documentation to understand which kind of format is accepted.",
        is_open=True,
        color='danger',
        duration=10000, ), ])
    reinit = html.Div([dbc.Alert(
        "Reinitialization caused by changing model type or pkl or data or instance.",
        is_open=True,
        color='info',
        duration=5000, ), ])
    init_network = html.Div([dbc.Alert(
        "Initialization.",
        is_open=True,
        color='info',
        duration=5000, ), ])
    # endregion

    # Returned by the 4-output draw_model callback when nothing should change.
    no_update_draw = (dash.no_update,) * 4

    ######################################################

    @app.callback(
        Output('page-content', 'children'),
        Input('url', 'pathname'))
    def display_page(pathname):
        """Return the layout matching the current URL, or None for an unknown path."""
        if pathname in ('/', '/home'):
            return page_home
        if pathname == '/application':
            # A fresh model/view pair is built on each visit of the application page.
            model_application = Model(names_models, dict_components, dic_solvers, dic_xtypes)
            return View(model_application).layout
        if pathname == '/course':
            return page_course
        return None

    @app.callback(Output('home-link', 'active'),
                  Output('course-link', 'active'),
                  Output('application-link', 'active'),
                  Input('url', 'pathname'))
    def navbar_state(pathname):
        """Return one boolean per entry of ``page_list``: True for the link of the displayed page."""
        # '/' displays the home page (see display_page), so highlight the home link there too.
        if pathname == '/':
            pathname = '/home'
        return tuple(pathname == f'/{page}' for page in page_list)

    # region ml type
    @app.callback(Output('solver_sat', 'options'),
                  Output('solver_sat', 'value'),
                  Output('explanation_type', 'options'),
                  Output('explanation_type', 'value'),
                  Input('ml_model_choice', 'value'),
                  prevent_initial_call=True
                  )
    def update_ml_type_options(ml_type):
        """Fill the solver and explanation-type selectors for the chosen model type.

        The first solver and the first explanation type are preselected; every
        selector is emptied when no model type is chosen or when the
        corresponding option list is empty.
        """
        if ml_type is None:
            return [], None, [], None
        solvers = dic_solvers[ml_type]
        xtypes = dic_xtypes[ml_type]
        solver = solvers[0] if solvers else None
        # explanation_type receives a list of keys of xtypes as its value.
        xtype = [next(iter(xtypes))] if xtypes else None
        return solvers, solver, xtypes, xtype
    # endregion

    # region pretrained model
    @app.callback(
        Output('pretrained_model_filename', 'children'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'filename'),
        prevent_initial_call=True)
    def select_model(ml_type, filename):
        """Display the uploaded model's filename; clear it when the model type changes."""
        ctx = dash.callback_context
        if ctx.triggered:
            if ctx.triggered_id == 'ml_model_choice':
                return None
            return filename
    # endregion

    # region data
    @app.callback(
        Output('add_info_model_choice', 'on'),
        Input('ml_model_choice', 'value'),
        prevent_initial_call=True
    )
    def delete_info(ml_type):
        """Switch the 'additional information' toggle off whenever the model type changes."""
        return False

    @app.callback(
        Output('choice_info_div', 'hidden'),
        Input('add_info_model_choice', 'on'),
        prevent_initial_call=True
    )
    def add_model_info(add_info_model_choice):
        """Hide the data-upload area unless the 'additional information' toggle is on."""
        return not add_info_model_choice

    @app.callback(Output('info_filename', 'children'),
                  Output('intermediate-value-data', 'data'),
                  Input('ml_model_choice', 'value'),
                  Input('model_info_choice', 'contents'),
                  State('model_info_choice', 'filename'),
                  prevent_initial_call=True
                  )
    def select_data(ml_type, data, filename):
        """Parse the uploaded additional data; clear it when the model type changes.

        Returns ``(displayed filename, parsed data)``.
        """
        ctx = dash.callback_context
        if ctx.triggered:
            if ctx.triggered_id == 'ml_model_choice':
                return None, None
            model_info = parse_contents_data(data, filename)
            return filename, model_info
    # endregion

    # region instance
    @app.callback(Output('instance_filename', 'children'),
                  Input('ml_model_choice', 'value'),
                  Input('ml_pretrained_model_choice', 'contents'),
                  Input('ml_instance_choice', 'contents'),
                  State('ml_instance_choice', 'filename'),
                  prevent_initial_call=True
                  )
    def select_instance(ml_type, model, instance, filename):
        """Display the uploaded instance's filename; clear it when the model type or model file changes."""
        ctx = dash.callback_context
        if ctx.triggered:
            if ctx.triggered_id in ('ml_model_choice', 'ml_pretrained_model_choice'):
                return None
            return filename
    # endregion

    # region draw
    @app.callback(
        Output('graph', 'children'),
        Output('explanation', 'children'),
        Output('expl_choice', 'options'),
        Output('cont_expl_choice', 'options'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        State('ml_pretrained_model_choice', 'filename'),
        State('add_info_model_choice', 'on'),
        State('intermediate-value-data', 'data'),
        Input('ml_instance_choice', 'contents'),
        State('ml_instance_choice', 'filename'),
        Input('submit-model', 'n_clicks'),
        State('model_info_choice', 'filename'),
        Input('expl_choice', 'value'),
        Input('cont_expl_choice', 'value'),
        Input('number_explanations', 'value'),
        Input('explanation_type', 'value'),
        Input('solver_sat', 'value'),
        Input('choice_tree', 'value'),
        prevent_initial_call=True)
    def draw_model(ml_type, pretrained_model, model_filename, need_data, data, instance, instance_filename, click,
                   model_info_filename, expl_choice, cont_expl_choice, enum, xtype, solver, id_tree):
        """Build the model component and return what the application page displays.

        Returns a 4-tuple ``(graph, explanation, expl_options, cont_expl_options)``.
        Changing the model type or the model file only resets the display. Any
        other trigger rebuilds the component from the uploaded files and then
        computes or draws what was asked for. Alert banners replace the output
        when an input is missing or a computation fails.
        """
        ctx = dash.callback_context
        if not ctx.triggered:
            return no_update_draw
        ihm_id = ctx.triggered_id

        if ihm_id == "ml_model_choice":
            return reinit, None, {}, {}
        if ihm_id == "ml_pretrained_model_choice":
            return init_network, None, {}, {}

        if ml_type is None:
            return warning_selection_model, None, {}, {}
        if pretrained_model is None and data is not None:
            return warning_selection_pretrained_model, None, {}, {}

        try:
            pretrained_model = parse_contents_graph(pretrained_model, model_filename)
            # dict_components maps the model type to the name of a component class
            # that must exist in this module's globals (e.g. DecisionTreeComponent).
            component_class = globals()[dict_components[ml_type]]
            if not need_data:
                component = component_class(pretrained_model)
            elif data is not None:
                component = component_class(pretrained_model, info=data, type_info=model_info_filename)
            else:
                return warning_selection_data, None, {}, {}
        except Exception:
            logger.exception("Could not build the %s component from %s", ml_type, model_filename)
            return alert_network, None, {}, {}

        if ihm_id == "submit-model":
            return component.network, None, {}, {}
        if ihm_id == 'choice_tree':
            # In the case of RandomForest, id of the tree to draw
            component.update_plotted_tree(id_tree)
            return component.network, component.explanation, {}, {}
        if instance is None:
            # An explanation setting changed before any instance was uploaded:
            # there is nothing to explain yet, so leave the page unchanged
            # (an implicit None is not a valid result for this 4-output callback).
            return no_update_draw

        try:
            instance = parse_contents_instance(instance, instance_filename)
            list_expls, list_cont_expls = component.update_with_explicability(instance, enum, xtype, solver)
            options_expls = {str(expl): expl for expl in list_expls}
            options_cont_expls = {str(expl): expl for expl in list_cont_expls}
            if ihm_id == 'expl_choice':
                component.draw_explanation(instance, expl_choice)
            elif ihm_id == 'cont_expl_choice':
                # Choice of CxP to draw
                component.draw_contrastive_explanation(instance, cont_expl_choice)
        except Exception:
            logger.exception("Could not compute the explanation for instance %s", instance_filename)
            return component.network, alert_explanation, {}, {}
        return component.network, component.explanation, options_expls, options_cont_expls
    # endregion

    # region explanation
    @app.callback(
        Output('explanation', 'hidden'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        State('ml_instance_choice', 'contents'),
        Input('submit-instance', 'n_clicks'),
        prevent_initial_call=True
    )
    def show_explanation_window(ml_type, model, instance, click):
        """Hide the explanation window when the model changes; show it on 'submit-instance'."""
        ctx = dash.callback_context
        if ctx.triggered:
            # NOTE(review): 'ml_instance_choice' is only a State here, so it can never be the
            # trigger; make it an Input if a new instance should also hide the window.
            return ctx.triggered_id in ("ml_model_choice", "ml_pretrained_model_choice", "ml_instance_choice")
    # endregion

    # region randomforest
    @app.callback(
        Output('choosing_tree', 'hidden'),
        Input('ml_model_choice', 'value'),
        prevent_initial_call=True
    )
    def choose_tree_in_forest(ml_type):
        """Show the tree selector only for RandomForest models."""
        return ml_type != "RandomForest"
    # endregion

    # region decisiontree
    @app.callback(
        Output('div_switcher_draw_expl', 'hidden'),
        Input('ml_model_choice', 'value'),
        prevent_initial_call=True
    )
    def show_switcher_draw(ml_type):
        """Show the 'draw explanation' switch only for DecisionTree models."""
        return ml_type != "DecisionTree"

    @app.callback(
        Output('drawing_expl', 'on'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        Input('model_info_choice', 'contents'),
        Input('ml_instance_choice', 'contents'),
        prevent_initial_call=True
    )
    def turn_switcher_draw_off(ml_type, model, data, instance):
        """Turn the 'draw explanation' switch off whenever a model, data or instance input changes."""
        return False

    @app.callback(
        Output('interaction_graph', 'hidden'),
        Input('drawing_expl', 'on'),
        prevent_initial_call=True
    )
    def switcher_drawing_options(bool_draw):
        """Hide the drawing options unless the 'draw explanation' switch is on."""
        return not bool_draw
    # endregion