Skip to content
Snippets Groups Projects
callbacks.py 13.60 KiB
import ast
import json
import dash
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
from dash import Input, Output, State, html
from dash import dcc

from utils import parse_contents_graph, parse_contents_instance, parse_contents_data
from pages.application.application import Model, View
from utils import extract_data

from pages.application.RandomForest.utils import xrf
from pages.application.RandomForest.utils.xrf import *
sys.modules['xrf'] = xrf


def register_callbacks(app):
    page_list = ['home', 'course', 'application']

    # For home directory
    page_home = dbc.Row([])
    # For course directory
    main_course = dcc.Tabs(children=[
        dcc.Tab(label='Data format', children=[]),
        dcc.Tab(label='Course Decision Tree', children=[])])
    page_course = dbc.Row([main_course])

    # Load the model catalogue consumed by extract_data. The ``with`` block
    # guarantees the file handle is closed (it previously stayed open).
    with open('data_retriever.json', encoding='utf-8') as models_data:
        data = json.load(models_data)["data"]
    # For the application
    names_models, dict_components, dic_solvers, dic_xtypes = extract_data(data)
    model_application = Model(names_models, dict_components, dic_solvers, dic_xtypes)
    view_application = View(model_application)

    ####### Alerts returned by callbacks when something goes wrong #########
    def make_alert(message, color, duration=10000):
        """Return an ``html.Div`` wrapping a single open ``dbc.Alert``.

        :param message: text displayed in the alert.
        :param color: dbc colour name ('warning', 'danger', 'info', ...).
        :param duration: forwarded unchanged to ``dbc.Alert(duration=...)``.
        """
        return html.Div([dbc.Alert(message, is_open=True, color=color, duration=duration)])

    warning_selection_model = make_alert(
        "You didn't choose a kind of Machine Learning model first.", 'warning')
    warning_selection_pretrained_model = make_alert(
        "You uploaded the data, now upload the pretrained model.", 'warning')
    warning_selection_data = make_alert(
        "You uploaded the model, now upload the data.", 'warning')
    # Adjacent string literals are joined at compile time; the former
    # backslash continuation inside the literal embedded a run of spaces.
    alert_network = make_alert(
        "There was a problem while computing the graph, read the documentation. "
        "You might have forgotten to upload the data for Random Forest or you tried to upload an unknown format.",
        'danger')
    alert_explanation = make_alert(
        "There was a problem while computing the explanation. "
        "Read the documentation to understand which kinds of format are accepted.",
        'danger')

    reinit = make_alert(
        "Reinitialization caused by changing model type or pkl or data or instance.",
        'info', duration=5000)

    init_network = make_alert("Initialization.", 'info', duration=5000)

    ######################################################

    @app.callback(
        Output('page-content', 'children'),
        Input('url', 'pathname'))
    def display_page(pathname):
        """Return the page layout matching the current URL path.

        ``'/'`` and ``'/home'`` both render the home page. ``'/home'`` is the
        path derived from the 'home' entry of ``page_list`` for navbar
        highlighting, and it used to render nothing.

        :param pathname: current value of the ``url`` location component.
        :return: the layout for the page, or ``None`` (empty content) for an
            unknown path.
        """
        if pathname in ('/', '/home'):
            return page_home
        if pathname == '/application':
            # Read lazily so the attribute is only touched when needed.
            return view_application.layout
        if pathname == '/course':
            return page_course
        return None

    @app.callback(Output('home-link', 'active'),
                  Output('course-link', 'active'),
                  Output('application-link', 'active'),
                  Input('url', 'pathname'))
    def navbar_state(pathname):
        """Compute the ``active`` flag of the home, course and application links.

        The home page is served at ``'/'``, so the root path is normalized to
        ``'/home'`` before comparing. Otherwise the home link would never be
        highlighted while the home page is displayed.

        :param pathname: current value of the ``url`` location component.
        :return: a 3-tuple of booleans, one per link in ``page_list`` order.
        """
        current = '/home' if pathname == '/' else pathname
        active = [current == f'/{name}' for name in page_list]
        return active[0], active[1], active[2]

    # region mltype
    @app.callback(Output('solver_sat', 'options'),
                  Output('solver_sat', 'value'),
                  Output('explanation_type', 'options'),
                  Output('explanation_type', 'value'),
                  Input('ml_model_choice', 'value'),
                  prevent_initial_call=True
                  )
    def update_ml_type_options(ml_type):
        """Refresh the solver and explanation-type selectors for a new model type.

        :param ml_type: value chosen in ``ml_model_choice``.
        :return: (solver options, first solver as default, explanation-type
            options, one-element list holding the first explanation-type key).
        """
        model_application.update_ml_model(ml_type)
        solver_options = model_application.solvers
        xtype_options = model_application.xtypes
        default_xtype = list(xtype_options.keys())[0]
        return solver_options, solver_options[0], xtype_options, [default_xtype]

    # endregion

    # region pretrained model
    @app.callback(
        Output('pretrained_model_filename', 'children'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        State('ml_pretrained_model_choice', 'filename'),
        prevent_initial_call=True)
    def select_model(ml_type, model, filename):
        """Load an uploaded pretrained model, or clear its filename on model-type change.

        :param ml_type: value of ``ml_model_choice`` (only used as a trigger).
        :param model: uploaded file contents from ``ml_pretrained_model_choice``.
        :param filename: name of the uploaded file.
        :return: the filename to display, or ``None``.
        """
        context = dash.callback_context
        if not context.triggered:
            return None
        if context.triggered_id == 'ml_model_choice':
            # A new model type invalidates the displayed pretrained-model name.
            return None
        parsed_model = parse_contents_graph(model, filename)
        model_application.update_pretrained_model(parsed_model)
        return filename

    # endregion

    # region data
    @app.callback(
        Output('add_info_model_choice', 'on'),
        Input('ml_model_choice', 'value'),
        prevent_initial_call=True
    )
    def delete_info(ml_type):
        """Turn the "additional model info" toggle off whenever the model type changes.

        The model is told that no extra info is needed, and the same ``False``
        is returned as the toggle's ``on`` property.
        """
        info_requested = False
        model_application.update_info_needed(info_requested)
        return info_requested

    @app.callback(
        Output('choice_info_div', 'hidden'),
        Input('add_info_model_choice', 'on'),
        prevent_initial_call=True
    )
    def add_model_info(add_info_model_choice):
        """Forward the toggle state to the model and show/hide the info upload area.

        :param add_info_model_choice: ``on`` state of the toggle.
        :return: ``hidden`` flag for ``choice_info_div``; the area is visible
            exactly when the toggle is on.
        """
        model_application.update_info_needed(add_info_model_choice)
        return not add_info_model_choice

    @app.callback(Output('info_filename', 'children'),
                  Output('intermediate-value-data', 'data'),
                  Input('ml_model_choice', 'value'),
                  Input('model_info_choice', 'contents'),
                  State('model_info_choice', 'filename'),
                  prevent_initial_call=True
                  )
    def select_data(ml_type, data, filename):
        """Parse the uploaded model-info file, or clear it when the model type changes.

        :param ml_type: value of ``ml_model_choice`` (only used as a trigger).
        :param data: uploaded contents from ``model_info_choice``.
        :param filename: name of the uploaded file.
        :return: (filename to display, parsed info to store), ``(None, None)``
            after a model-type change, or ``None`` if nothing triggered.
        """
        context = dash.callback_context
        if not context.triggered:
            return None
        if context.triggered_id == 'ml_model_choice':
            return None, None
        parsed_info = parse_contents_data(data, filename)
        model_application.update_info(parsed_info)
        return filename, parsed_info

    # endregion

    # region instance
    @app.callback(Output('instance_filename', 'children'),
                  Input('ml_model_choice', 'value'),
                  Input('ml_pretrained_model_choice', 'contents'),
                  Input('ml_instance_choice', 'contents'),
                  State('ml_instance_choice', 'filename'),
                  Input('number_explanations', 'value'),
                  Input('explanation_type', 'value'),
                  Input('solver_sat', 'value'),
                  prevent_initial_call=True
                  )
    def select_instance(ml_type, model, instance, filename, enum, xtype, solver):
        """Push explanation settings or a newly uploaded instance to the model.

        A change of model type or pretrained model clears the displayed
        instance filename. A change of explanation count, explanation type or
        solver updates the corresponding model setting. Otherwise the uploaded
        instance is parsed and stored. In those last cases the current
        instance filename is returned unchanged.
        """
        context = dash.callback_context
        if not context.triggered:
            return None
        source = context.triggered_id
        if source in ('ml_model_choice', 'ml_pretrained_model_choice'):
            return None

        if source == 'number_explanations':
            # Choice of number of expls
            model_application.update_enum(enum)
        elif source == 'explanation_type':
            # Choice of AxP or CxP
            model_application.update_xtype(xtype)
        elif source == 'solver_sat':
            # Choice of solver
            model_application.update_solver(solver)
        else:
            parsed_instance = parse_contents_instance(instance, filename)
            model_application.update_instance(parsed_instance)
        return filename

    # endregion

    # region draw
    @app.callback(
        Output('graph', 'children'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        Input('submit-model', 'n_clicks'),
        State('model_info_choice', 'filename'),
        Input('expl_choice', 'value'),
        Input('cont_expl_choice', 'value'),
        Input('choice_tree', 'value'),
        prevent_initial_call=True)
    def draw_model(ml_type, model, click, model_info_filename, expl_choice, cont_expl_choice, id_tree):
        """Fill the ``graph`` area according to whichever input fired.

        - model type changed       -> ``reinit`` notice
        - pretrained model changed -> ``init_network`` notice
        - "submit-model" clicked   -> the model network, or a warning if the
          model type or the required extra info is missing
        - explanation / contrastive explanation / tree choice -> update the
          model with that choice and return the network again

        :param model_info_filename: name of the uploaded model-info file,
            forwarded to ``update_pretrained_model_layout_with_info``.
        :return: Dash children for ``graph``; ``alert_network`` if anything
            raises; ``None`` when there is no (known) trigger.
        """
        import logging

        ctx = dash.callback_context
        try:
            if not ctx.triggered:
                return None
            ihm_id = ctx.triggered_id
            if ihm_id == "ml_model_choice":
                return reinit
            if ihm_id == "ml_pretrained_model_choice":
                return init_network
            if ihm_id == "submit-model":
                if model_application.ml_model is None:
                    return warning_selection_model
                if not model_application.add_info:
                    model_application.update_pretrained_model_layout()
                    return model_application.component.network
                if model_application.model_info is not None:
                    model_application.update_pretrained_model_layout_with_info(model_info_filename)
                    return model_application.component.network
                return warning_selection_data

            if ihm_id == 'expl_choice':
                model_application.update_expl(expl_choice)
            elif ihm_id == 'cont_expl_choice':
                # Choice of CxP to draw
                model_application.update_cont_expl(cont_expl_choice)
            elif ihm_id == 'choice_tree':
                # In the case of RandomForest, id of tree to choose to draw tree
                model_application.update_tree_to_plot(id_tree)
            else:
                return None
            return model_application.component.network
        except Exception:
            # Still show the user an alert, but keep the traceback instead of
            # discarding it. A bare ``except:`` would also trap
            # KeyboardInterrupt/SystemExit.
            logging.getLogger(__name__).exception("Failed to compute the model graph")
            return alert_network

    # endregion

    # region explanation

    @app.callback(
        Output('explanation', 'hidden'),
        Input('submit-instance', 'n_clicks'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        Input('ml_instance_choice', 'contents'),
        prevent_initial_call=True
    )
    def show_explanation_window(click, ml_type, model, instance):
        """Return the ``hidden`` flag of the explanation panel.

        Only a click on "submit-instance" reveals the panel. Any other trigger
        (model type, model file, instance file) hides it.
        """
        context = dash.callback_context
        if not context.triggered:
            return None
        return context.triggered_id != "submit-instance"

    @app.callback(
        Output('explanation', 'children'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        Input('ml_instance_choice', 'contents'),
        Input('submit-instance', 'n_clicks'),
        prevent_initial_call=True)
    def run_explanation(ml_type, model, instance, click):
        """Fill the explanation panel.

        Changing the model type, the pretrained model or the instance returns
        the ``reinit`` notice. Clicking "submit-instance" returns
        ``model_application.component.explanation`` when a model type, a
        pretrained model and an instance are all set, or the matching warning
        otherwise. If computing the explanation raises, ``alert_explanation``
        is returned. With no (known) trigger the result is ``None``.
        """
        import logging

        ctx = dash.callback_context
        if not ctx.triggered:
            return None
        ihm_id = ctx.triggered_id
        if ihm_id in ("ml_model_choice", "ml_pretrained_model_choice", "ml_instance_choice"):
            return reinit
        if ihm_id != "submit-instance":
            return None
        try:
            if model_application.ml_model is None:
                return warning_selection_model
            if model_application.pretrained_model is None:
                return warning_selection_pretrained_model
            if model_application.instance is None:
                # NOTE(review): the missing piece here is the instance, but the
                # generic "now upload the data" warning is reused; confirm wording.
                return warning_selection_data
            return model_application.component.explanation
        except Exception:
            # Surface an alert to the user and log the traceback. The former
            # bare ``except:`` hid every error, including KeyboardInterrupt.
            logging.getLogger(__name__).exception("Failed to compute the explanation")
            return alert_explanation

    # endregion

    ########### RandomForest ###########

    @app.callback(
        Output('choosing_tree', 'hidden'),
        Input('ml_model_choice', 'value'),
        prevent_initial_call=True
    )
    def choose_tree_in_forest(ml_type):
        """Hide the tree selector unless the chosen model type is "RandomForest"."""
        is_forest = ml_type == "RandomForest"
        return not is_forest

    ########### DecisionTree ###########

    @app.callback(
        Output('div_switcher_draw_expl', 'hidden'),
        Input('ml_model_choice', 'value'),
        prevent_initial_call=True
    )
    def show_switcher_draw(ml_type):
        """Hide the draw-explanation switch unless the model type is "DecisionTree"."""
        is_decision_tree = ml_type == "DecisionTree"
        return not is_decision_tree

    @app.callback(
        Output('drawing_expl', 'on'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        Input('model_info_choice', 'contents'),
        Input('ml_instance_choice', 'contents'),
        prevent_initial_call=True
    )
    def turn_switcher_draw_off(ml_type, model, data, instance):
        """Switch off explanation drawing whenever the model type, model, info or instance changes."""
        switch_state = False
        return switch_state

    @app.callback(
        Output('interaction_graph', 'hidden'),
        Output('expl_choice', 'options'),
        Output('cont_expl_choice', 'options'),
        Input('drawing_expl', 'on'),
        prevent_initial_call=True
    )
    def switcher_drawing_options(bool_draw):
        """Show the explanation pickers, filled from the model, when drawing is on.

        :param bool_draw: ``on`` state of the ``drawing_expl`` switch.
        :return: (``hidden`` flag of ``interaction_graph``, explanation
            options, contrastive-explanation options). Both option sets are
            empty dicts while drawing is off.
        """
        if bool_draw:
            return False, model_application.options_expls, model_application.options_cont_expls
        return True, {}, {}