from dash import html
import dash_interactive_graphviz
from sklearn import tree

from pages.application.RandomForest.utils import xrf
from pages.application.RandomForest.utils.xrf.xforest import XRF, Dataset


class RandomForestComponent:
    """Dash component that plots and explains a random-forest model.

    When CSV dataset information is supplied, the model is wrapped in an
    ``XRF`` object. One of its trees is rendered as an interactive Graphviz
    graph (``self.network``), and explanation HTML blocks can be built for
    given instances (``self.explanation``).
    """

    def __init__(self, model, info=None, type_info=''):
        """Wrap *model* and render its first tree.

        Args:
            model: the trained forest handed to ``XRF``.
            info: dataset information passed to ``Dataset``.
            type_info: description of *info*'s format; the model is only
                wrapped in ``XRF`` when this string contains ``'csv'``.
        """
        # Conversion model
        self.data = Dataset(info)
        # Always define the attribute so later methods can test for it; only
        # CSV input is converted to an XRF forest.
        self.random_forest = None
        if info is not None and 'csv' in type_info:
            self.random_forest = XRF(model, self.data.feature_names, self.data.target_name)

        self.tree_to_plot = 0
        self.network = self._build_network()
        self.explanation = []

    def _build_network(self):
        """Return an ``html.Div`` rendering tree ``self.tree_to_plot`` via Graphviz.

        Returns an empty ``html.Div`` when no forest was built (non-CSV input),
        instead of failing on ``None.cls``.
        """
        if self.random_forest is None:
            return html.Div([])
        estimator = self.random_forest.cls.estimators()[self.tree_to_plot]
        dot_source = tree.export_graphviz(
            estimator,
            feature_names=self.data.feature_names,
            class_names=[str(cl) for cl in self.data.target_name],
            impurity=False,
            filled=False,
            rounded=True,
        )
        return html.Div([dash_interactive_graphviz.DashInteractiveGraphviz(
            dot_source=dot_source,
            style={"width": "50%", "height": "80%", "background-color": "transparent"},
        )])

    def update_with_explicability(self, instances, enum_feats=None, xtype=None, solver=None):
        """Rebuild ``self.explanation`` with one section per instance.

        Args:
            instances: iterable of instances, each a sequence of pairs whose
                second item is the value. All values but the last are
                displayed against ``self.data.feature_names``; the last one is
                displayed as the target.
            enum_feats, xtype, solver: accepted for interface compatibility;
                not used by this method.

        Returns:
            ``([], [])``.
        """
        # Keep only the value of each (feature, value) pair.
        instances = [[feature[1] for feature in instance] for instance in instances]

        self.explanation = []
        for instance in instances:
            self.explanation.append(html.H4("Sample : "))
            # Call instance
            self.explanation.append(html.H5("Instance"))
            self.explanation.append(html.Hr())
            feature_pairs = [(self.data.feature_names[i], str(instance[i]))
                             for i in range(len(instance) - 1)]
            target_pair = (self.data.target_name, str(instance[-1]))
            self.explanation.append(html.P(str(feature_pairs) + " THEN " + str(target_pair)))
            self.explanation.append(html.Hr())

            # Only an XRF forest can explain an instance; without one there is
            # no result to render (calling .keys() on None would crash).
            if not isinstance(self.random_forest, XRF):
                continue
            explanation_result = self.random_forest.explain(instance)
            # Creating a clean and nice text component
            for key, value in explanation_result.items():
                self.explanation.append(html.H5(key))
                self.explanation.append(html.Hr())
                self.explanation.append(html.P(value))
                self.explanation.append(html.Hr())

        # Drop the `enc` / `x` attributes from the forest (presumably left
        # there by explain()); guarded because they do not exist when explain()
        # never ran (no instances, or no XRF forest).
        for attr in ("enc", "x"):
            if hasattr(self.random_forest, attr):
                delattr(self.random_forest, attr)

        return [], []

    def update_plotted_tree(self, tree_to_plot):
        """Select tree index *tree_to_plot* and re-render ``self.network``."""
        self.tree_to_plot = tree_to_plot
        self.network = self._build_network()