diff --git a/callbacks.py b/callbacks.py index 06b112f9a3e78088f141fe86f96c473977e239bd..906b09c41a107b7943cb9842a8a4a293fd5a649a 100644 --- a/callbacks.py +++ b/callbacks.py @@ -14,6 +14,9 @@ from pages.application.RandomForest.utils import xrf from pages.application.RandomForest.utils.xrf import * sys.modules['xrf'] = xrf +from sklearn.ensemble._voting import VotingClassifier +from sklearn.ensemble import RandomForestClassifier + from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent from pages.application.RandomForest.RandomForestComponent import RandomForestComponent @@ -30,12 +33,13 @@ def register_callbacks(app): # For course directory main_course = dcc.Tabs(children=[ dcc.Tab(label='Data format', children=[]), - dcc.Tab(label='Course Decision Tree', children=[])]) + dcc.Tab(label='Course Decision Tree', children=[]), + dcc.Tab(label='Course Random Forest', children=[])]) page_course = dbc.Row([main_course]) + # For the application models_data = open('data_retriever.json') data = json.load(models_data)["data"] - # For the application names_models, dict_components, dic_solvers, dic_xtypes = extract_data(data) # region alerts @@ -43,7 +47,7 @@ def register_callbacks(app): is_open=True, color='warning', duration=10000, ), ]) - warning_selection_pretrained_model = html.Div([dbc.Alert("You uploaded the data, now upload the pretrained model.", + warning_selection_pretrained_model = html.Div([dbc.Alert("Upload the pretrained model.", is_open=True, color='warning', duration=10000, ), ]) @@ -113,7 +117,7 @@ def register_callbacks(app): xtypes = dic_xtypes[ml_type] xtype = [list(xtypes.keys())[0]] return solvers, solver, xtypes, xtype - else : + else: return [], None, [], None # endregion @@ -130,7 +134,7 @@ def register_callbacks(app): ihm_id = ctx.triggered_id if ihm_id == 'ml_model_choice': return None - else : + else: return filename # endregion @@ -178,22 +182,21 @@ def register_callbacks(app): @app.callback(Output('instance_filename', 'children'), Input('ml_model_choice', 'value'), Input('ml_pretrained_model_choice', 'contents'), - Input('ml_instance_choice', 'contents'), - State('ml_instance_choice', 'filename'), + Input('ml_instance_choice', 'filename'), prevent_initial_call=True ) - def select_instance(ml_type, model, instance, filename): + def select_instance(ml_type, model, filename): ctx = dash.callback_context if ctx.triggered: ihm_id = ctx.triggered_id if ihm_id == 'ml_model_choice' or ihm_id == "ml_pretrained_model_choice": return None - else : + else: return filename # endregion - # region draw + # region main @app.callback( Output('graph', 'children'), Output('explanation', 'children'), @@ -215,24 +218,31 @@ def register_callbacks(app): Input('solver_sat', 'value'), Input('choice_tree', 'value'), prevent_initial_call=True) - def draw_model(ml_type, pretrained_model, model_filename, need_data, data, instance, instance_filename, click, model_info_filename, + def draw_model(ml_type, pretrained_model, model_filename, need_data, data, instance, instance_filename, click, + model_info_filename, expl_choice, cont_expl_choice, enum, xtype, solver, id_tree): ctx = dash.callback_context if ctx.triggered: ihm_id = ctx.triggered_id - if ihm_id == "ml_model_choice" : + # selecting a machine learning model type initialize the component + if ihm_id == "ml_model_choice": return reinit, None, {}, {} + # uploading a model elif ihm_id == "ml_pretrained_model_choice": return init_network, None, {}, {} - else : + else: + # construction of the component base from pretrained model + # catching exception if the construction of the component fails with try/except/else try: - pretrained_model = parse_contents_graph(pretrained_model, model_filename) if ml_type is None: return warning_selection_model, None, {}, {} + elif pretrained_model is None : + return warning_selection_pretrained_model, None, {}, {} else: + pretrained_model = parse_contents_graph(pretrained_model, model_filename) component_class = dict_components[ml_type] component_class = globals()[component_class] if not need_data: @@ -241,40 +251,33 @@ def register_callbacks(app): component = component_class(pretrained_model, info=data, type_info=model_info_filename) else: return warning_selection_data, None, {}, {} - except: return alert_network, None, {}, {} - else : + else: + # plotting model by clicking "submit" button if ihm_id == "submit-model": return component.network, None, {}, {} + # construction of explanation + if instance is not None: + try: + instance = parse_contents_instance(instance, instance_filename) + component.update_with_explicability(instance, enum, xtype, solver) + # In the case of DecisionTree, plotting explanation + if ihm_id == 'expl_choice': + component.draw_explanation(instance, expl_choice) + # # In the case of DecisionTree, plotting cont explanation + elif ihm_id == 'cont_expl_choice': + component.draw_contrastive_explanation(instance, cont_expl_choice) + except: + return component.network, alert_explanation, {}, {} + # In the case of RandomForest, id of tree to choose to draw tree - elif ihm_id == 'choice_tree': + if ihm_id == 'choice_tree': component.update_plotted_tree(id_tree) - return component.network, component.explanation, {}, {} - - elif instance is not None: - instance = parse_contents_instance(instance, instance_filename) - list_expls, list_cont_expls = component.update_with_explicability(instance, enum, xtype, solver) - options_expls = {} - options_cont_expls = {} - for i in range(len(list_expls)): - options_expls[str(list_expls[i])] = list_expls[i] - for i in range(len(list_cont_expls)): - options_cont_expls[str(list_cont_expls[i])] = list_cont_expls[i] - - if ihm_id == 'expl_choice': - component.draw_explanation(instance, expl_choice) - return component.network, component.explanation, options_expls, options_cont_expls - - # Choice of CxP to draw - elif ihm_id == 'cont_expl_choice': - component.draw_contrastive_explanation(instance, cont_expl_choice) - return component.network, component.explanation, options_expls, options_cont_expls - - else : - return component.network, component.explanation, options_expls, options_cont_expls + + return component.network, component.explanation, component.options_expls, component.options_cont_expls # endregion @@ -284,17 +287,16 @@ def register_callbacks(app): Output('explanation', 'hidden'), Input('ml_model_choice', 'value'), Input('ml_pretrained_model_choice', 'contents'), - State('ml_instance_choice', 'contents'), Input('submit-instance', 'n_clicks'), prevent_initial_call=True ) - def show_explanation_window(ml_type, model, instante, click): + def show_explanation_window(ml_type, model, click): ctx = dash.callback_context if ctx.triggered: ihm_id = ctx.triggered_id - if ihm_id == "ml_model_choice" or ihm_id == "ml_pretrained_model_choice" or ihm_id == "ml_instance_choice": + if ihm_id == "ml_model_choice" or ihm_id == "ml_pretrained_model_choice": return True - else : + else: return False # endregion @@ -309,6 +311,24 @@ def register_callbacks(app): def choose_tree_in_forest(ml_type): return bool(ml_type != "RandomForest") + @app.callback( + Output("choice_tree", "max"), + State('ml_model_choice', 'value'), + Input('ml_pretrained_model_choice', 'contents'), + State('ml_pretrained_model_choice', 'filename'), + prevent_initial_call=True) + def adjust_slider_max(ml_type, pretrained_model, model_filename): + if ml_type == "RandomForest": + pretrained_model = parse_contents_graph(pretrained_model, model_filename) + if isinstance(pretrained_model, xrf.rndmforest.RF2001): + return int(pretrained_model.forest.n_estimators) + elif isinstance(pretrained_model, RandomForestClassifier): + return pretrained_model.n_estimators + elif isinstance(pretrained_model, VotingClassifier): + return len(pretrained_model.estimators) + else: + return 0 + # endregion # region decisiontree @@ -340,4 +360,4 @@ def register_callbacks(app): def switcher_drawing_options(bool_draw): return not bool_draw - # endregion \ No newline at end of file + # endregion diff --git a/data_retriever.json b/data_retriever.json index 1607a1a1fcc43cca6c7025fa1ea37db6076e535f..e0051a0150d14c7b9315b1d50b3907f1f72ae4ad 100644 --- a/data_retriever.json +++ b/data_retriever.json @@ -8,21 +8,20 @@ "g3", "g4", "lgl", "mcb", "mcm", "mpl", "m22", "mc", "mgh" ], "xtypes" : { - "AXp": "Abductive Explanation", "CXp": "Contrastive explanation"} - }, - { - "ml_type" : "NaiveBayes", - "component" : "NaiveBayesComponent", - "solvers" : [], - "xtypes" : { - "AXp": "Abductive Explanation", "CXp": "Contrastive explanation"} + "AXp": " Abductive ", "CXp": " Contrastive "} }, { "ml_type" : "RandomForest", "component" : "RandomForestComponent", "solvers" : ["SAT"], - "xtypes" : {"M": "Minimal explanation"} + "xtypes" : {"abd": " Abductive ", "con" : " Contrastive "} + }, + { + "ml_type" : "NaiveBayes", + "component" : "NaiveBayesComponent", + "solvers" : ["SAT"], + "xtypes" : {"abd": " Abductive ", "con" : " Contrastive "} } ] -} \ No newline at end of file +} diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py index 81f00168f609cc3bd2261bb5c4de42ae84542e18..96f8146573b3cd4394c10144ca3580da70d3478a 100644 --- a/pages/application/DecisionTree/DecisionTreeComponent.py +++ b/pages/application/DecisionTree/DecisionTreeComponent.py @@ -62,6 +62,8 @@ class DecisionTreeComponent: "background-color": "transparent"})]) # init explanation self.explanation = [] + self.options_cont_expls = {} + self.options_expls = {} def create_fvmap_inverse(self, instance): def create_fvmap_inverse_with_info(features_names_mapping): @@ -158,7 +160,12 @@ class DecisionTreeComponent: "background-color": "transparent"} )]) - return list_explanations_path, list_contrastive_explanations_path + self.options_expls = {} + self.options_cont_expls = {} + for i in range(len(list_explanations_path)): + self.options_expls[str(list_explanations_path[i])] = list_explanations_path[i] + for i in range(len(list_contrastive_explanations_path)): + self.options_cont_expls[str(list_contrastive_explanations_path[i])] = list_contrastive_explanations_path[i] def draw_explanation(self, instance, expl): r""" Called with the selection of an explanation to plot on the tree diff --git a/pages/application/RandomForest/RandomForestComponent.py b/pages/application/RandomForest/RandomForestComponent.py index 38522ba30353ae06e6a54059f5dde35e587fecdd..19704e42a8ef3dcb913c3c2d27c2459225eaf0ba 100644 --- a/pages/application/RandomForest/RandomForestComponent.py +++ b/pages/application/RandomForest/RandomForestComponent.py @@ -1,8 +1,13 @@ +import re + +import numpy from dash import html import dash_interactive_graphviz from sklearn import tree from pages.application.RandomForest.utils import xrf from pages.application.RandomForest.utils.xrf.xforest import XRF, Dataset +from sklearn.ensemble._voting import VotingClassifier +from sklearn.ensemble import RandomForestClassifier class RandomForestComponent: @@ -15,14 +20,25 @@ class RandomForestComponent: # creation of model if info is not None and 'csv' in type_info: - self.random_forest = XRF(model, self.data.feature_names, self.data.target_name) + if isinstance(model, xrf.rndmforest.RF2001): + self.random_forest = XRF(model, self.data.feature_names, self.data.target_name) + elif isinstance(model, RandomForestClassifier): + params = {'n_trees': model.n_estimators, + 'depth': model.max_depth} + cls = xrf.rndmforest.RF2001(**params) + train_accuracy, test_accuracy = cls.train(self.data) + self.random_forest = XRF(cls, self.data.feature_names, self.data.target_name) # encoding here so not in the explanation ? # visual self.tree_to_plot = 0 dot_source = tree.export_graphviz(self.random_forest.cls.estimators()[self.tree_to_plot], feature_names=self.data.feature_names, class_names=list(map(lambda cl : str(cl), self.data.target_name)), - impurity=False, filled=False, rounded=True) + impurity=False, filled=True, rounded=True) + dot_source = re.sub(r"(samples) \= [0-9]+", '', dot_source) + dot_source = re.sub(r"\\nvalue \= \[[\d+|\,|\s]+\]\\n", '', dot_source) + dot_source = re.sub(r"\\nclass \= \d+", '', dot_source) + self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( dot_source=dot_source, style={"width": "50%", "height": "80%", @@ -31,14 +47,16 @@ class RandomForestComponent: # init explanation self.explanation = [] + self.options_cont_expls = {} + self.options_expls = {} - def update_with_explicability(self, instances, enum_feats=None, xtype=None, solver=None): + def update_with_explicability(self, instances, enum_feats=None, xtypes=None, solver=None): r""" Called when an instance is upload or when you press the button "Submit for explanation" with advanced parameters. Args: instances : list - list of instance to explain enum_feats : ghost feature - xtype : ghost feature - solver : ghost feature + xtypes : types of explanation + solver : solver, only SAT available for the moment """ instances = [list(map(lambda feature: feature[1], instance)) for instance in instances] self.explanation = [] @@ -54,18 +72,18 @@ class RandomForestComponent: self.explanation.append(html.Hr()) # Call explanation - explanation_result = None - if isinstance(self.random_forest, XRF): - explanation_result = self.random_forest.explain(instance) - # Creating a clean and nice text component - for k in explanation_result.keys(): - self.explanation.append(html.H5(k)) - self.explanation.append(html.Hr()) - self.explanation.append(html.P(explanation_result[k])) - self.explanation.append(html.Hr()) + xtypes_trad = {"abd": " Abductive ", "con" : "Contrastive "} + for xtype in xtypes : + explanation_result = self.random_forest.explain(instance, xtype) + # Creating a clean and nice text component + for k in explanation_result.keys(): + self.explanation.append(html.H5(xtypes_trad[xtype] + k)) + self.explanation.append(html.Hr()) + self.explanation.append(html.P(explanation_result[k])) + self.explanation.append(html.Hr()) - del self.random_forest.enc - del self.random_forest.x + del self.random_forest.enc + del self.random_forest.x return [], [] @@ -78,7 +96,10 @@ class RandomForestComponent: self.tree_to_plot = tree_to_plot dot_source = tree.export_graphviz(self.random_forest.cls.estimators()[self.tree_to_plot], feature_names=self.data.feature_names, class_names=list(map(lambda cl : str(cl), self.data.target_name)), - impurity=False, filled=False, rounded=True) + impurity=False, filled=True, rounded=True) + dot_source = re.sub(r"(samples) \= [0-9]+", '', dot_source) + dot_source = re.sub(r"\\nvalue \= \[[\d+|\,|\s]+\]\\n", '', dot_source) + dot_source = re.sub(r"\\nclass \= \d+", '', dot_source) self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( dot_source=dot_source, style={"width": "50%", diff --git a/pages/application/application.py b/pages/application/application.py index c53d260ec6fba2e8d863efb3dc5e79325bff7bd5..ced07ee303faca2b26bae369f3669cecb3d52d1e 100644 --- a/pages/application/application.py +++ b/pages/application/application.py @@ -123,7 +123,7 @@ class View: inline=True), html.Hr()]) - self.solver = html.Div([html.Label("Choose the SAT solver : "), + self.solver = html.Div([html.Label("Choose the solver : "), html.P(), dcc.Dropdown(self.model.solvers, id='solver_sat')]) @@ -169,9 +169,10 @@ class View: self.tree_to_plot = html.Div(id="choosing_tree", hidden=True, children=[html.H5("Choose a tree to plot: "), - html.Div(children=[dcc.Slider(0, 50, 1, - value=0, - id='choice_tree')])]) + html.Div(children=[dcc.Slider(0, 20, 1, marks=None, + tooltip={"placement": "bottom", "always_visible": True}, + value=0, + id='choice_tree')])]) self.layout = dbc.Row([dbc.Col([self.sidebar], width=3, class_name="sidebar"),