# callbacks.py
import ast
import json
import dash
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
from dash import Input, Output, State, html
from dash import dcc

from utils import parse_contents_graph, parse_contents_instance, parse_contents_data
from pages.application.application import Model, View
from utils import extract_data

from pages.application.RandomForest.utils import xrf
from pages.application.RandomForest.utils.xrf import *
sys.modules['xrf'] = xrf

from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent
from pages.application.RandomForest.RandomForestComponent import RandomForestComponent

"""
The callbacks are called whenever there is an interaction with the interface
"""


def register_callbacks(app):
    """Register every callback of the interface on ``app``.

    The callbacks are declared as closures so that they can share the page
    layouts, the model configuration read from ``data_retriever.json`` and
    the alert components built at the top of this function.

    :param app: object exposing the ``callback`` decorator used by the
        closures (presumably the ``dash.Dash`` application -- confirm with
        the caller).
    """
    page_list = ['home', 'course', 'application']

    # For home directory
    page_home = dbc.Row([])
    # For course directory
    main_course = dcc.Tabs(children=[
        dcc.Tab(label='Data format', children=[]),
        dcc.Tab(label='Course Decision Tree', children=[])])
    page_course = dbc.Row([main_course])

    # Read the configuration inside a context manager so the file handle is
    # released as soon as the JSON has been parsed.
    with open('data_retriever.json') as models_data:
        data = json.load(models_data)["data"]
    # For the application
    names_models, dict_components, dic_solvers, dic_xtypes = extract_data(data)

    # region alerts
    # Each alert is wrapped in a Div and closes itself after `duration` milliseconds.
    warning_selection_model = html.Div([dbc.Alert("You didn't choose a kind of Machine Learning model first.",
                                                  is_open=True,
                                                  color='warning',
                                                  duration=10000, ), ])
    warning_selection_pretrained_model = html.Div([dbc.Alert("You uploaded the data, now upload the pretrained model.",
                                                             is_open=True,
                                                             color='warning',
                                                             duration=10000, ), ])
    warning_selection_data = html.Div([dbc.Alert("You uploaded the model, now upload the data.",
                                                 is_open=True,
                                                 color='warning',
                                                 duration=10000, ), ])
    alert_network = html.Div([dbc.Alert("There was a problem while computing the graph, read the documentation. \
                                        You might have forgotten to upload the data for Random Forest or you tried to upload an unknown format.",
                                        is_open=True,
                                        color='danger',
                                        duration=10000, ), ])
    alert_explanation = html.Div([dbc.Alert(
        "There was a problem while computing the explanation. Read the documentation to understand which kind of format are accepted.",
        is_open=True,
        color='danger',
        duration=10000, ), ])

    reinit = html.Div([dbc.Alert(
        "Reinitialization caused by changing model type or pkl or data or instance.",
        is_open=True,
        color='info',
        duration=5000, ), ])

    init_network = html.Div([dbc.Alert(
        "Initialization.",
        is_open=True,
        color='info',
        duration=5000, ), ])

    # endregion

    ######################################################

    @app.callback(
        Output('page-content', 'children'),
        Input('url', 'pathname'))
    def display_page(pathname):
        """Return the content to show for the URL path ``pathname``.

        ``'/'`` and ``'/course'`` map to the layouts built once at
        registration time; ``'/application'`` gets a layout created on each
        call. Any other path yields ``None``.
        """
        if pathname == '/application':
            # A new Model/View pair is built every time this page is requested.
            application_view = View(Model(names_models, dict_components, dic_solvers, dic_xtypes))
            return application_view.layout

        prebuilt_pages = {'/': page_home, '/course': page_course}
        return prebuilt_pages.get(pathname)

    @app.callback(Output('home-link', 'active'),
                  Output('course-link', 'active'),
                  Output('application-link', 'active'),
                  Input('url', 'pathname'))
    def navbar_state(pathname):
        """Tell each navigation link whether it matches the current URL path.

        :param pathname: value of the ``pathname`` property of the ``url`` component.
        :return: three booleans, in the order home, course, application.
        """
        home_name, course_name, application_name = page_list
        # `display_page` serves the home layout at '/', so that path has to
        # activate the home link too; '/home' is still accepted as before.
        home_active = pathname in ('/', f'/{home_name}')
        course_active = pathname == f'/{course_name}'
        application_active = pathname == f'/{application_name}'
        return home_active, course_active, application_active

    # region ml type
    @app.callback(Output('solver_sat', 'options'),
                  Output('solver_sat', 'value'),
                  Output('explanation_type', 'options'),
                  Output('explanation_type', 'value'),
                  Input('ml_model_choice', 'value'),
                  prevent_initial_call=True
                  )
    def update_ml_type_options(ml_type):
        """Fill the solver and explanation-type selectors for the chosen model type.

        Returns ``(solver options, selected solver, explanation-type options,
        selected explanation types)``. When no model type is selected, both
        option lists are emptied and both values are reset to ``None``.
        """
        if ml_type is None:
            return [], None, [], None

        available_solvers = dic_solvers[ml_type]
        default_solver = available_solvers[0]
        available_xtypes = dic_xtypes[ml_type]
        # The selected explanation type is a one-element list holding the first key.
        default_xtype = [list(available_xtypes.keys())[0]]
        return available_solvers, default_solver, available_xtypes, default_xtype

    # endregion

    # region pretrained model
    @app.callback(
        Output('pretrained_model_filename', 'children'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'filename'),
        prevent_initial_call=True)
    def select_model(ml_type, filename):
        """Show the uploaded model's file name, cleared when the model type changes."""
        ctx = dash.callback_context
        if not ctx.triggered:
            return None

        model_type_changed = ctx.triggered_id == 'ml_model_choice'
        return None if model_type_changed else filename

    # endregion

    # region data
    @app.callback(
        Output('add_info_model_choice', 'on'),
        Input('ml_model_choice', 'value'),
        prevent_initial_call=True,
    )
    def delete_info(ml_type):
        """Switch ``add_info_model_choice`` off every time the model type changes."""
        switch_on = False
        return switch_on

    @app.callback(
        Output('choice_info_div', 'hidden'),
        Input('add_info_model_choice', 'on'),
        prevent_initial_call=True
    )
    def add_model_info(add_info_model_choice):
        """Hide ``choice_info_div`` unless the ``add_info_model_choice`` switch is on."""
        return not add_info_model_choice

    @app.callback(Output('info_filename', 'children'),
                  Output('intermediate-value-data', 'data'),
                  Input('ml_model_choice', 'value'),
                  Input('model_info_choice', 'contents'),
                  State('model_info_choice', 'filename'),
                  prevent_initial_call=True
                  )
    def select_data(ml_type, data, filename):
        """Store the uploaded model information and show its file name.

        A change of model type clears both the displayed name and the stored
        data; otherwise the upload is parsed with ``parse_contents_data``.
        """
        ctx = dash.callback_context
        if not ctx.triggered:
            return None

        if ctx.triggered_id == 'ml_model_choice':
            return None, None

        return filename, parse_contents_data(data, filename)

    # endregion

    # region instance
    @app.callback(Output('instance_filename', 'children'),
                  Input('ml_model_choice', 'value'),
                  Input('ml_pretrained_model_choice', 'contents'),
                  Input('ml_instance_choice', 'contents'),
                  State('ml_instance_choice', 'filename'),
                  prevent_initial_call=True
                  )
    def select_instance(ml_type, model, instance, filename):
        """Show the uploaded instance's file name, cleared when the model type or model changes."""
        ctx = dash.callback_context
        if not ctx.triggered:
            return None

        resetting_triggers = ('ml_model_choice', "ml_pretrained_model_choice")
        if ctx.triggered_id in resetting_triggers:
            return None
        return filename

    # endregion

    # region draw
    @app.callback(
        Output('graph', 'children'),
        Output('explanation', 'children'),
        Output('expl_choice', 'options'),
        Output('cont_expl_choice', 'options'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        State('ml_pretrained_model_choice', 'filename'),
        State('add_info_model_choice', 'on'),
        State('intermediate-value-data', 'data'),
        Input('ml_instance_choice', 'contents'),
        State('ml_instance_choice', 'filename'),
        Input('submit-model', 'n_clicks'),
        State('model_info_choice', 'filename'),
        Input('expl_choice', 'value'),
        Input('cont_expl_choice', 'value'),
        Input('number_explanations', 'value'),
        Input('explanation_type', 'value'),
        Input('solver_sat', 'value'),
        Input('choice_tree', 'value'),
        prevent_initial_call=True)
    def draw_model(ml_type, pretrained_model, model_filename, need_data, data, instance, instance_filename, click, model_info_filename,
                   expl_choice, cont_expl_choice, enum, xtype, solver, id_tree):
        """Build the model component and refresh the graph, explanation and explanation selectors.

        The action depends on which input triggered the callback:

        * ``ml_model_choice`` / ``ml_pretrained_model_choice``: show an
          informational alert and reset the other outputs;
        * ``submit-model``: draw the network only;
        * ``choice_tree``: redraw with the tree ``id_tree`` selected;
        * any other input: when an instance has been uploaded, compute the
          explanations and, for ``expl_choice`` / ``cont_expl_choice``, draw
          the selected one. Without an instance nothing is updated.

        :return: a 4-tuple ``(graph children, explanation children,
            expl_choice options, cont_expl_choice options)``.
        """
        # A multi-output Dash callback must return a list or tuple; returning
        # None raises an error. Paths with nothing to draw return this instead
        # so the current outputs are left untouched.
        nothing_to_update = (dash.no_update,) * 4

        ctx = dash.callback_context
        if not ctx.triggered:
            return nothing_to_update

        ihm_id = ctx.triggered_id
        if ihm_id == "ml_model_choice":
            return reinit, None, {}, {}
        if ihm_id == "ml_pretrained_model_choice":
            return init_network, None, {}, {}

        try:
            pretrained_model = parse_contents_graph(pretrained_model, model_filename)
            if ml_type is None:
                return warning_selection_model, None, {}, {}

            # `dict_components` gives the component class name; the class itself
            # is looked up among this module's globals.
            component_class = globals()[dict_components[ml_type]]
            if not need_data:
                component = component_class(pretrained_model)
            elif data is not None:
                component = component_class(pretrained_model, info=data, type_info=model_info_filename)
            else:
                return warning_selection_data, None, {}, {}
        except Exception:
            # Any failure while parsing or building the component is reported to
            # the user. `Exception` (rather than a bare except) lets
            # KeyboardInterrupt and SystemExit propagate.
            return alert_network, None, {}, {}

        if ihm_id == "submit-model":
            return component.network, None, {}, {}

        # In the case of RandomForest, id of tree to choose to draw tree
        if ihm_id == 'choice_tree':
            component.update_plotted_tree(id_tree)
            return component.network, component.explanation, {}, {}

        if instance is None:
            return nothing_to_update

        instance = parse_contents_instance(instance, instance_filename)
        list_expls, list_cont_expls = component.update_with_explicability(instance, enum, xtype, solver)
        # Selector options map the string form of each explanation to the explanation itself.
        options_expls = {str(expl): expl for expl in list_expls}
        options_cont_expls = {str(cont_expl): cont_expl for cont_expl in list_cont_expls}

        if ihm_id == 'expl_choice':
            component.draw_explanation(instance, expl_choice)
        elif ihm_id == 'cont_expl_choice':
            # Choice of CxP to draw
            component.draw_contrastive_explanation(instance, cont_expl_choice)

        return component.network, component.explanation, options_expls, options_cont_expls

    # endregion

    # region explanation

    @app.callback(
        Output('explanation', 'hidden'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        State('ml_instance_choice', 'contents'),
        Input('submit-instance', 'n_clicks'),
        prevent_initial_call=True
    )
    def show_explanation_window(ml_type, model, instante, click):
        """Hide the explanation panel on a model change and reveal it for any other trigger."""
        ctx = dash.callback_context
        if not ctx.triggered:
            return None

        hiding_triggers = ("ml_model_choice", "ml_pretrained_model_choice", "ml_instance_choice")
        return ctx.triggered_id in hiding_triggers

    # endregion

    # region randomforest

    @app.callback(
        Output('choosing_tree', 'hidden'),
        Input('ml_model_choice', 'value'),
        prevent_initial_call=True
    )
    def choose_tree_in_forest(ml_type):
        """Hide ``choosing_tree`` unless the selected model type is ``"RandomForest"``."""
        is_random_forest = ml_type == "RandomForest"
        return not is_random_forest

    # endregion

    # region decisiontree

    @app.callback(
        Output('div_switcher_draw_expl', 'hidden'),
        Input('ml_model_choice', 'value'),
        prevent_initial_call=True
    )
    def show_switcher_draw(ml_type):
        """Hide ``div_switcher_draw_expl`` unless the selected model type is ``"DecisionTree"``."""
        is_decision_tree = ml_type == "DecisionTree"
        return not is_decision_tree

    @app.callback(
        Output('drawing_expl', 'on'),
        Input('ml_model_choice', 'value'),
        Input('ml_pretrained_model_choice', 'contents'),
        Input('model_info_choice', 'contents'),
        Input('ml_instance_choice', 'contents'),
        prevent_initial_call=True,
    )
    def turn_switcher_draw_off(ml_type, model, data, instance):
        """Switch ``drawing_expl`` off whenever the model type, model, data or instance changes."""
        switch_on = False
        return switch_on

    @app.callback(
        Output('interaction_graph', 'hidden'),
        Input('drawing_expl', 'on'),
        prevent_initial_call=True,
    )
    def switcher_drawing_options(bool_draw):
        """Show ``interaction_graph`` only while the ``drawing_expl`` switch is on."""
        hidden = not bool_draw
        return hidden

    # endregion