Skip to content
Snippets Groups Projects
Commit 015c3256 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

change in architecture for heroku

parent 04c432ad
Branches
No related tags found
No related merge requests found
...@@ -14,6 +14,9 @@ from pages.application.RandomForest.utils import xrf ...@@ -14,6 +14,9 @@ from pages.application.RandomForest.utils import xrf
from pages.application.RandomForest.utils.xrf import * from pages.application.RandomForest.utils.xrf import *
sys.modules['xrf'] = xrf sys.modules['xrf'] = xrf
from sklearn.ensemble._voting import VotingClassifier
from sklearn.ensemble import RandomForestClassifier
from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent
from pages.application.RandomForest.RandomForestComponent import RandomForestComponent from pages.application.RandomForest.RandomForestComponent import RandomForestComponent
...@@ -30,12 +33,13 @@ def register_callbacks(app): ...@@ -30,12 +33,13 @@ def register_callbacks(app):
# For course directory # For course directory
main_course = dcc.Tabs(children=[ main_course = dcc.Tabs(children=[
dcc.Tab(label='Data format', children=[]), dcc.Tab(label='Data format', children=[]),
dcc.Tab(label='Course Decision Tree', children=[])]) dcc.Tab(label='Course Decision Tree', children=[]),
dcc.Tab(label='Course Random Forest', children=[])])
page_course = dbc.Row([main_course]) page_course = dbc.Row([main_course])
# For the application
models_data = open('data_retriever.json') models_data = open('data_retriever.json')
data = json.load(models_data)["data"] data = json.load(models_data)["data"]
# For the application
names_models, dict_components, dic_solvers, dic_xtypes = extract_data(data) names_models, dict_components, dic_solvers, dic_xtypes = extract_data(data)
# region alerts # region alerts
...@@ -43,7 +47,7 @@ def register_callbacks(app): ...@@ -43,7 +47,7 @@ def register_callbacks(app):
is_open=True, is_open=True,
color='warning', color='warning',
duration=10000, ), ]) duration=10000, ), ])
warning_selection_pretrained_model = html.Div([dbc.Alert("You uploaded the data, now upload the pretrained model.", warning_selection_pretrained_model = html.Div([dbc.Alert("Upload the pretrained model.",
is_open=True, is_open=True,
color='warning', color='warning',
duration=10000, ), ]) duration=10000, ), ])
...@@ -113,7 +117,7 @@ def register_callbacks(app): ...@@ -113,7 +117,7 @@ def register_callbacks(app):
xtypes = dic_xtypes[ml_type] xtypes = dic_xtypes[ml_type]
xtype = [list(xtypes.keys())[0]] xtype = [list(xtypes.keys())[0]]
return solvers, solver, xtypes, xtype return solvers, solver, xtypes, xtype
else : else:
return [], None, [], None return [], None, [], None
# endregion # endregion
...@@ -130,7 +134,7 @@ def register_callbacks(app): ...@@ -130,7 +134,7 @@ def register_callbacks(app):
ihm_id = ctx.triggered_id ihm_id = ctx.triggered_id
if ihm_id == 'ml_model_choice': if ihm_id == 'ml_model_choice':
return None return None
else : else:
return filename return filename
# endregion # endregion
...@@ -178,22 +182,21 @@ def register_callbacks(app): ...@@ -178,22 +182,21 @@ def register_callbacks(app):
@app.callback(Output('instance_filename', 'children'), @app.callback(Output('instance_filename', 'children'),
Input('ml_model_choice', 'value'), Input('ml_model_choice', 'value'),
Input('ml_pretrained_model_choice', 'contents'), Input('ml_pretrained_model_choice', 'contents'),
Input('ml_instance_choice', 'contents'), Input('ml_instance_choice', 'filename'),
State('ml_instance_choice', 'filename'),
prevent_initial_call=True prevent_initial_call=True
) )
def select_instance(ml_type, model, instance, filename): def select_instance(ml_type, model, filename):
ctx = dash.callback_context ctx = dash.callback_context
if ctx.triggered: if ctx.triggered:
ihm_id = ctx.triggered_id ihm_id = ctx.triggered_id
if ihm_id == 'ml_model_choice' or ihm_id == "ml_pretrained_model_choice": if ihm_id == 'ml_model_choice' or ihm_id == "ml_pretrained_model_choice":
return None return None
else : else:
return filename return filename
# endregion # endregion
# region draw # region main
@app.callback( @app.callback(
Output('graph', 'children'), Output('graph', 'children'),
Output('explanation', 'children'), Output('explanation', 'children'),
...@@ -215,24 +218,31 @@ def register_callbacks(app): ...@@ -215,24 +218,31 @@ def register_callbacks(app):
Input('solver_sat', 'value'), Input('solver_sat', 'value'),
Input('choice_tree', 'value'), Input('choice_tree', 'value'),
prevent_initial_call=True) prevent_initial_call=True)
def draw_model(ml_type, pretrained_model, model_filename, need_data, data, instance, instance_filename, click, model_info_filename, def draw_model(ml_type, pretrained_model, model_filename, need_data, data, instance, instance_filename, click,
model_info_filename,
expl_choice, cont_expl_choice, enum, xtype, solver, id_tree): expl_choice, cont_expl_choice, enum, xtype, solver, id_tree):
ctx = dash.callback_context ctx = dash.callback_context
if ctx.triggered: if ctx.triggered:
ihm_id = ctx.triggered_id ihm_id = ctx.triggered_id
if ihm_id == "ml_model_choice" : # selecting a machine learning model type initialize the component
if ihm_id == "ml_model_choice":
return reinit, None, {}, {} return reinit, None, {}, {}
# uploading a model
elif ihm_id == "ml_pretrained_model_choice": elif ihm_id == "ml_pretrained_model_choice":
return init_network, None, {}, {} return init_network, None, {}, {}
else : else:
# construction of the component base from pretrained model
# catching exception if the construction of the component fails with try/except/else
try: try:
pretrained_model = parse_contents_graph(pretrained_model, model_filename)
if ml_type is None: if ml_type is None:
return warning_selection_model, None, {}, {} return warning_selection_model, None, {}, {}
elif pretrained_model is None :
return warning_selection_pretrained_model, None, {}, {}
else: else:
pretrained_model = parse_contents_graph(pretrained_model, model_filename)
component_class = dict_components[ml_type] component_class = dict_components[ml_type]
component_class = globals()[component_class] component_class = globals()[component_class]
if not need_data: if not need_data:
...@@ -241,40 +251,33 @@ def register_callbacks(app): ...@@ -241,40 +251,33 @@ def register_callbacks(app):
component = component_class(pretrained_model, info=data, type_info=model_info_filename) component = component_class(pretrained_model, info=data, type_info=model_info_filename)
else: else:
return warning_selection_data, None, {}, {} return warning_selection_data, None, {}, {}
except: except:
return alert_network, None, {}, {} return alert_network, None, {}, {}
else : else:
# plotting model by clicking "submit" button
if ihm_id == "submit-model": if ihm_id == "submit-model":
return component.network, None, {}, {} return component.network, None, {}, {}
# construction of explanation
if instance is not None:
try:
instance = parse_contents_instance(instance, instance_filename)
component.update_with_explicability(instance, enum, xtype, solver)
# In the case of DecisionTree, plotting explanation
if ihm_id == 'expl_choice':
component.draw_explanation(instance, expl_choice)
# # In the case of DecisionTree, plotting cont explanation
elif ihm_id == 'cont_expl_choice':
component.draw_contrastive_explanation(instance, cont_expl_choice)
except:
return component.network, alert_explanation, {}, {}
# In the case of RandomForest, id of tree to choose to draw tree # In the case of RandomForest, id of tree to choose to draw tree
elif ihm_id == 'choice_tree': if ihm_id == 'choice_tree':
component.update_plotted_tree(id_tree) component.update_plotted_tree(id_tree)
return component.network, component.explanation, {}, {}
return component.network, component.explanation, component.options_expls, component.options_cont_expls
elif instance is not None:
instance = parse_contents_instance(instance, instance_filename)
list_expls, list_cont_expls = component.update_with_explicability(instance, enum, xtype, solver)
options_expls = {}
options_cont_expls = {}
for i in range(len(list_expls)):
options_expls[str(list_expls[i])] = list_expls[i]
for i in range(len(list_cont_expls)):
options_cont_expls[str(list_cont_expls[i])] = list_cont_expls[i]
if ihm_id == 'expl_choice':
component.draw_explanation(instance, expl_choice)
return component.network, component.explanation, options_expls, options_cont_expls
# Choice of CxP to draw
elif ihm_id == 'cont_expl_choice':
component.draw_contrastive_explanation(instance, cont_expl_choice)
return component.network, component.explanation, options_expls, options_cont_expls
else :
return component.network, component.explanation, options_expls, options_cont_expls
# endregion # endregion
...@@ -284,17 +287,16 @@ def register_callbacks(app): ...@@ -284,17 +287,16 @@ def register_callbacks(app):
Output('explanation', 'hidden'), Output('explanation', 'hidden'),
Input('ml_model_choice', 'value'), Input('ml_model_choice', 'value'),
Input('ml_pretrained_model_choice', 'contents'), Input('ml_pretrained_model_choice', 'contents'),
State('ml_instance_choice', 'contents'),
Input('submit-instance', 'n_clicks'), Input('submit-instance', 'n_clicks'),
prevent_initial_call=True prevent_initial_call=True
) )
def show_explanation_window(ml_type, model, instante, click): def show_explanation_window(ml_type, model, click):
ctx = dash.callback_context ctx = dash.callback_context
if ctx.triggered: if ctx.triggered:
ihm_id = ctx.triggered_id ihm_id = ctx.triggered_id
if ihm_id == "ml_model_choice" or ihm_id == "ml_pretrained_model_choice" or ihm_id == "ml_instance_choice": if ihm_id == "ml_model_choice" or ihm_id == "ml_pretrained_model_choice":
return True return True
else : else:
return False return False
# endregion # endregion
...@@ -309,6 +311,24 @@ def register_callbacks(app): ...@@ -309,6 +311,24 @@ def register_callbacks(app):
def choose_tree_in_forest(ml_type): def choose_tree_in_forest(ml_type):
return bool(ml_type != "RandomForest") return bool(ml_type != "RandomForest")
@app.callback(
Output("choice_tree", "max"),
State('ml_model_choice', 'value'),
Input('ml_pretrained_model_choice', 'contents'),
State('ml_pretrained_model_choice', 'filename'),
prevent_initial_call=True)
def adjust_slider_max(ml_type, pretrained_model, model_filename):
if ml_type == "RandomForest":
pretrained_model = parse_contents_graph(pretrained_model, model_filename)
if isinstance(pretrained_model, xrf.rndmforest.RF2001):
return int(pretrained_model.forest.n_estimators)
elif isinstance(pretrained_model, RandomForestClassifier):
return pretrained_model.n_estimators
elif isinstance(pretrained_model, VotingClassifier):
return len(pretrained_model.estimators)
else:
return 0
# endregion # endregion
# region decisiontree # region decisiontree
...@@ -340,4 +360,4 @@ def register_callbacks(app): ...@@ -340,4 +360,4 @@ def register_callbacks(app):
def switcher_drawing_options(bool_draw): def switcher_drawing_options(bool_draw):
return not bool_draw return not bool_draw
# endregion # endregion
\ No newline at end of file
...@@ -8,21 +8,20 @@ ...@@ -8,21 +8,20 @@
"g3", "g4", "lgl", "mcb", "mcm", "mpl", "m22", "mc", "mgh" "g3", "g4", "lgl", "mcb", "mcm", "mpl", "m22", "mc", "mgh"
], ],
"xtypes" : { "xtypes" : {
"AXp": "Abductive Explanation", "CXp": "Contrastive explanation"} "AXp": " Abductive ", "CXp": " Contrastive "}
},
{
"ml_type" : "NaiveBayes",
"component" : "NaiveBayesComponent",
"solvers" : [],
"xtypes" : {
"AXp": "Abductive Explanation", "CXp": "Contrastive explanation"}
}, },
{ {
"ml_type" : "RandomForest", "ml_type" : "RandomForest",
"component" : "RandomForestComponent", "component" : "RandomForestComponent",
"solvers" : ["SAT"], "solvers" : ["SAT"],
"xtypes" : {"M": "Minimal explanation"} "xtypes" : {"abd": " Abductive ", "con" : " Contrastive "}
},
{
"ml_type" : "NaiveBayes",
"component" : "NaiveBayesComponent",
"solvers" : ["SAT"],
"xtypes" : {"abd": " Abductive ", "con" : " Contrastive "}
} }
] ]
} }
\ No newline at end of file
...@@ -62,6 +62,8 @@ class DecisionTreeComponent: ...@@ -62,6 +62,8 @@ class DecisionTreeComponent:
"background-color": "transparent"})]) "background-color": "transparent"})])
# init explanation # init explanation
self.explanation = [] self.explanation = []
self.options_cont_expls = {}
self.options_expls = {}
def create_fvmap_inverse(self, instance): def create_fvmap_inverse(self, instance):
def create_fvmap_inverse_with_info(features_names_mapping): def create_fvmap_inverse_with_info(features_names_mapping):
...@@ -158,7 +160,12 @@ class DecisionTreeComponent: ...@@ -158,7 +160,12 @@ class DecisionTreeComponent:
"background-color": "transparent"} "background-color": "transparent"}
)]) )])
return list_explanations_path, list_contrastive_explanations_path self.options_expls = {}
self.options_cont_expls = {}
for i in range(len(list_explanations_path)):
self.options_expls[str(list_explanations_path[i])] = list_explanations_path[i]
for i in range(len(list_contrastive_explanations_path)):
self.options_cont_expls[str(list_contrastive_explanations_path[i])] = list_contrastive_explanations_path[i]
def draw_explanation(self, instance, expl): def draw_explanation(self, instance, expl):
r""" Called with the selection of an explanation to plot on the tree r""" Called with the selection of an explanation to plot on the tree
......
import re
import numpy
from dash import html from dash import html
import dash_interactive_graphviz import dash_interactive_graphviz
from sklearn import tree from sklearn import tree
from pages.application.RandomForest.utils import xrf from pages.application.RandomForest.utils import xrf
from pages.application.RandomForest.utils.xrf.xforest import XRF, Dataset from pages.application.RandomForest.utils.xrf.xforest import XRF, Dataset
from sklearn.ensemble._voting import VotingClassifier
from sklearn.ensemble import RandomForestClassifier
class RandomForestComponent: class RandomForestComponent:
...@@ -15,14 +20,25 @@ class RandomForestComponent: ...@@ -15,14 +20,25 @@ class RandomForestComponent:
# creation of model # creation of model
if info is not None and 'csv' in type_info: if info is not None and 'csv' in type_info:
self.random_forest = XRF(model, self.data.feature_names, self.data.target_name) if isinstance(model, xrf.rndmforest.RF2001):
self.random_forest = XRF(model, self.data.feature_names, self.data.target_name)
elif isinstance(model, RandomForestClassifier):
params = {'n_trees': model.n_estimators,
'depth': model.max_depth}
cls = xrf.rndmforest.RF2001(**params)
train_accuracy, test_accuracy = cls.train(self.data)
self.random_forest = XRF(cls, self.data.feature_names, self.data.target_name)
# encoding here so not in the explanation ? # encoding here so not in the explanation ?
# visual # visual
self.tree_to_plot = 0 self.tree_to_plot = 0
dot_source = tree.export_graphviz(self.random_forest.cls.estimators()[self.tree_to_plot], dot_source = tree.export_graphviz(self.random_forest.cls.estimators()[self.tree_to_plot],
feature_names=self.data.feature_names, class_names=list(map(lambda cl : str(cl), self.data.target_name)), feature_names=self.data.feature_names, class_names=list(map(lambda cl : str(cl), self.data.target_name)),
impurity=False, filled=False, rounded=True) impurity=False, filled=True, rounded=True)
dot_source = re.sub(r"(samples) \= [0-9]+", '', dot_source)
dot_source = re.sub(r"\\nvalue \= \[[\d+|\,|\s]+\]\\n", '', dot_source)
dot_source = re.sub(r"\\nclass \= \d+", '', dot_source)
self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz(
dot_source=dot_source, style={"width": "50%", dot_source=dot_source, style={"width": "50%",
"height": "80%", "height": "80%",
...@@ -31,14 +47,16 @@ class RandomForestComponent: ...@@ -31,14 +47,16 @@ class RandomForestComponent:
# init explanation # init explanation
self.explanation = [] self.explanation = []
self.options_cont_expls = {}
self.options_expls = {}
def update_with_explicability(self, instances, enum_feats=None, xtype=None, solver=None): def update_with_explicability(self, instances, enum_feats=None, xtypes=None, solver=None):
r""" Called when an instance is upload or when you press the button "Submit for explanation" with advanced parameters. r""" Called when an instance is upload or when you press the button "Submit for explanation" with advanced parameters.
Args: Args:
instances : list - list of instance to explain instances : list - list of instance to explain
enum_feats : ghost feature enum_feats : ghost feature
xtype : ghost feature xtypes : types of explanation
solver : ghost feature solver : solver, only SAT available for the moment
""" """
instances = [list(map(lambda feature: feature[1], instance)) for instance in instances] instances = [list(map(lambda feature: feature[1], instance)) for instance in instances]
self.explanation = [] self.explanation = []
...@@ -54,18 +72,18 @@ class RandomForestComponent: ...@@ -54,18 +72,18 @@ class RandomForestComponent:
self.explanation.append(html.Hr()) self.explanation.append(html.Hr())
# Call explanation # Call explanation
explanation_result = None xtypes_trad = {"abd": " Abductive ", "con" : "Contrastive "}
if isinstance(self.random_forest, XRF): for xtype in xtypes :
explanation_result = self.random_forest.explain(instance) explanation_result = self.random_forest.explain(instance, xtype)
# Creating a clean and nice text component # Creating a clean and nice text component
for k in explanation_result.keys(): for k in explanation_result.keys():
self.explanation.append(html.H5(k)) self.explanation.append(html.H5(xtypes_trad[xtype] + k))
self.explanation.append(html.Hr()) self.explanation.append(html.Hr())
self.explanation.append(html.P(explanation_result[k])) self.explanation.append(html.P(explanation_result[k]))
self.explanation.append(html.Hr()) self.explanation.append(html.Hr())
del self.random_forest.enc del self.random_forest.enc
del self.random_forest.x del self.random_forest.x
return [], [] return [], []
...@@ -78,7 +96,10 @@ class RandomForestComponent: ...@@ -78,7 +96,10 @@ class RandomForestComponent:
self.tree_to_plot = tree_to_plot self.tree_to_plot = tree_to_plot
dot_source = tree.export_graphviz(self.random_forest.cls.estimators()[self.tree_to_plot], dot_source = tree.export_graphviz(self.random_forest.cls.estimators()[self.tree_to_plot],
feature_names=self.data.feature_names, class_names=list(map(lambda cl : str(cl), self.data.target_name)), feature_names=self.data.feature_names, class_names=list(map(lambda cl : str(cl), self.data.target_name)),
impurity=False, filled=False, rounded=True) impurity=False, filled=True, rounded=True)
dot_source = re.sub(r"(samples) \= [0-9]+", '', dot_source)
dot_source = re.sub(r"\\nvalue \= \[[\d+|\,|\s]+\]\\n", '', dot_source)
dot_source = re.sub(r"\\nclass \= \d+", '', dot_source)
self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz(
dot_source=dot_source, style={"width": "50%", dot_source=dot_source, style={"width": "50%",
......
...@@ -123,7 +123,7 @@ class View: ...@@ -123,7 +123,7 @@ class View:
inline=True), inline=True),
html.Hr()]) html.Hr()])
self.solver = html.Div([html.Label("Choose the SAT solver : "), self.solver = html.Div([html.Label("Choose the solver : "),
html.P(), html.P(),
dcc.Dropdown(self.model.solvers, dcc.Dropdown(self.model.solvers,
id='solver_sat')]) id='solver_sat')])
...@@ -169,9 +169,10 @@ class View: ...@@ -169,9 +169,10 @@ class View:
self.tree_to_plot = html.Div(id="choosing_tree", hidden=True, self.tree_to_plot = html.Div(id="choosing_tree", hidden=True,
children=[html.H5("Choose a tree to plot: "), children=[html.H5("Choose a tree to plot: "),
html.Div(children=[dcc.Slider(0, 50, 1, html.Div(children=[dcc.Slider(0, 20, 1, marks=None,
value=0, tooltip={"placement": "bottom", "always_visible": True},
id='choice_tree')])]) value=0,
id='choice_tree')])])
self.layout = dbc.Row([dbc.Col([self.sidebar], self.layout = dbc.Row([dbc.Col([self.sidebar],
width=3, class_name="sidebar"), width=3, class_name="sidebar"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment