From c3433fa9e6ea314bc004e82aa296fcaf0aa532a8 Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <cdepourt@montana.irit.fr> Date: Wed, 8 Jun 2022 12:08:24 +0200 Subject: [PATCH] cleaning instances --- .../DecisionTree/DecisionTreeComponent.py | 22 ++++---- pages/application/DecisionTree/utils/dtviz.py | 4 +- .../RandomForest/RandomForestComponent.py | 53 +++++++++---------- .../RandomForest/utils/xrf/xforest.py | 4 -- pages/application/application.py | 2 +- utils.py | 33 +++++++++--- 6 files changed, 63 insertions(+), 55 deletions(-) diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py index 23dc317..19c7904 100644 --- a/pages/application/DecisionTree/DecisionTreeComponent.py +++ b/pages/application/DecisionTree/DecisionTreeComponent.py @@ -4,10 +4,7 @@ from dash import html from pages.application.DecisionTree.utils.data import Data from pages.application.DecisionTree.utils.dtree import DecisionTree -from pages.application.DecisionTree.utils.dtviz import (visualize, - visualize_expl, - visualize_instance, - visualize_contrastive_expl) +from pages.application.DecisionTree.utils.dtviz import * from pages.application.DecisionTree.utils.upload_tree import UploadedDecisionTree @@ -116,19 +113,14 @@ class DecisionTreeComponent: def update_with_explicability(self, instance, enum, xtype, solver): + instance = instance[0] instance_translated = self.translate_instance(instance) + self.explanation = [] list_explanations_path = [] list_contrastive_explanations_path = [] explanation = self.dt.explain(instance_translated, enum=enum, xtype=xtype, solver=solver) - dot_source = visualize_instance(self.dt, instance_translated) - self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( - dot_source=dot_source, style={"width": "50%", - "height": "80%", - "background-color": "transparent"} - )]) - # Creating a clean and nice text component # instance plotting self.explanation.append(html.H4("Instance : \n")) @@ -147,6 +139,14 @@ class DecisionTreeComponent: list_explanations_path = explanation["List of path explanation(s)"] list_contrastive_explanations_path = explanation["List of path contrastive explanation(s)"] + # Create graph + dot_source = visualize_instance(self.dt, instance_translated) + self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( + dot_source=dot_source, style={"width": "50%", + "height": "80%", + "background-color": "transparent"} + )]) + return list_explanations_path, list_contrastive_explanations_path def draw_explanation(self, instance, expl): diff --git a/pages/application/DecisionTree/utils/dtviz.py b/pages/application/DecisionTree/utils/dtviz.py index 02aa847..5390899 100755 --- a/pages/application/DecisionTree/utils/dtviz.py +++ b/pages/application/DecisionTree/utils/dtviz.py @@ -32,13 +32,13 @@ def create_legend(G): legend.add_edge(edge) edge = pydot.Edge("e", "f") - edge.obj_dict['attributes']["label"] = "contrastive explanation" + edge.obj_dict['attributes']["label"] = "contrastive \n explanation" edge.obj_dict['attributes']["color"] = "red" edge.obj_dict['attributes']["style"] = "dashed" legend.add_edge(edge) edge = pydot.Edge("c", "d") - edge.obj_dict['attributes']["label"] = "instance with explanation" + edge.obj_dict['attributes']["label"] = "instance with \n explanation" edge.obj_dict['attributes']["color"] = "blue" edge.obj_dict['attributes']["style"] = "dashed" legend.add_edge(edge) diff --git a/pages/application/RandomForest/RandomForestComponent.py b/pages/application/RandomForest/RandomForestComponent.py index fc548d9..dc532d4 100644 --- a/pages/application/RandomForest/RandomForestComponent.py +++ b/pages/application/RandomForest/RandomForestComponent.py @@ -11,16 +11,14 @@ class RandomForestComponent: # Conversion model self.data = Dataset(info) - self.data.mapping_features() if info is not None and 'csv' in type_info: self.random_forest = XRF(model, self.data.feature_names, self.data.target_name) # encoding here so not in the explanation - self.tree_to_plot = 0 dot_source = tree.export_graphviz(self.random_forest.cls.estimators()[self.tree_to_plot], - feature_names=self.data.feature_names, class_names=self.data.class_names, + feature_names=self.data.feature_names, class_names=list(map(lambda cl : str(cl), self.data.target_name)), impurity=False, filled=False, rounded=True) self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( dot_source=dot_source, style={"width": "50%", @@ -29,41 +27,38 @@ class RandomForestComponent: )]) self.explanation = [] - def update_with_explicability(self, instance, enum_feats=None, xtype=None, solver=None): - - instance = instance[0] - if "=" in instance: - splitted_instance = [float(v.split('=')[1].strip()) for v in instance.split(',')] - else: - splitted_instance = [float(v.strip()) for v in instance.split(',')] + def update_with_explicability(self, instances, enum_feats=None, xtype=None, solver=None): - # Call instance + instances = [list(map(lambda feature: feature[1], instance)) for instance in instances] self.explanation = [] - self.explanation.append(html.H5("Instance")) - self.explanation.append(html.Hr()) - self.explanation.append( - html.P(str([tuple((self.data.feature_names[i], str(splitted_instance[i]))) for i in - range(len(splitted_instance) - 1)]) + " THEN " + str(tuple((self.data.target_name, str(splitted_instance[-1])))))) - self.explanation.append(html.Hr()) - - # Call explanation - explanation_result = None - if isinstance(self.random_forest, XRF): - explanation_result = self.random_forest.explain(splitted_instance) - list_explanations_path = [] - # Creating a clean and nice text component - for k in explanation_result.keys(): - self.explanation.append(html.H5(k)) + for instance in instances: + self.explanation.append(html.H4("Sample : ")) + # Call instance + self.explanation.append(html.H5("Instance")) self.explanation.append(html.Hr()) - self.explanation.append(html.P(explanation_result[k])) + self.explanation.append( + html.P(str([tuple((self.data.feature_names[i], str(instance[i]))) for i in + range(len(instance) - 1)]) + " THEN " + str( + tuple((self.data.target_name, str(instance[-1])))))) self.explanation.append(html.Hr()) - return list_explanations_path, [] + # Call explanation + explanation_result = None + if isinstance(self.random_forest, XRF): + explanation_result = self.random_forest.explain(instance) + # Creating a clean and nice text component + for k in explanation_result.keys(): + self.explanation.append(html.H5(k)) + self.explanation.append(html.Hr()) + self.explanation.append(html.P(explanation_result[k])) + self.explanation.append(html.Hr()) + + return [], [] def update_plotted_tree(self, tree_to_plot): self.tree_to_plot = tree_to_plot dot_source = tree.export_graphviz(self.random_forest.cls.estimators()[self.tree_to_plot], - feature_names=self.data.feature_names, class_names=self.data.class_names, + feature_names=self.data.feature_names, class_names=list(map(lambda cl : str(cl), self.data.target_name)), impurity=False, filled=False, rounded=True) self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( dot_source=dot_source, style={"width": "50%", diff --git a/pages/application/RandomForest/utils/xrf/xforest.py b/pages/application/RandomForest/utils/xrf/xforest.py index a15685c..749e3c4 100644 --- a/pages/application/RandomForest/utils/xrf/xforest.py +++ b/pages/application/RandomForest/utils/xrf/xforest.py @@ -43,10 +43,6 @@ class Dataset(Data): le.fit(samples[:, -1]) samples[:, -1] = le.transform(samples[:, -1]) self.class_names = le.classes_ - print(le.classes_) - print(samples[1:4, :]) - else : - self.class_names = np.unique(samples[:, -1]) samples = np.asarray(samples, dtype=np.float32) self.X = samples[:, 0: self.nb_features] diff --git a/pages/application/application.py b/pages/application/application.py index 2edf406..3f88c12 100644 --- a/pages/application/application.py +++ b/pages/application/application.py @@ -219,7 +219,7 @@ class View: self.tree_to_plot = html.Div(id="choosing_tree", hidden=True, children=[html.H5("Choose a tree to plot: "), - html.Div(children=[dcc.Slider(0, 100, 1, + html.Div(children=[dcc.Slider(0, 50, 1, value=0, id='choice_tree')])]) diff --git a/utils.py b/utils.py index d171f1b..67e9af7 100644 --- a/utils.py +++ b/utils.py @@ -10,6 +10,7 @@ from pages.application.RandomForest.utils import xrf from pages.application.RandomForest.utils.xrf import * sys.modules['xrf'] = xrf + def parse_contents_graph(contents, filename): content_type, content_string = contents.split(',') decoded = base64.b64decode(content_string) @@ -17,7 +18,7 @@ def parse_contents_graph(contents, filename): if '.pkl' in filename: try: data = joblib.load(io.BytesIO(decoded)) - except : + except: data = pickle.load(io.BytesIO(decoded)) elif '.txt' in filename: data = decoded.decode('utf-8').strip() @@ -47,31 +48,47 @@ def parse_contents_data(contents, filename): return data +def split_instance_according_to_format(instance, features_names=None): + if "=" in instance: + splitted_instance = [tuple((v.split('=')[0].strip(), float(v.split('=')[1].strip()))) for v in + instance.split(',')] + else: + if features_names: + splitted_instance = [tuple((features_names[i], float(instance.split(',')[i].strip()))) for i in + range(len(instance.split(',')))] + else: + splitted_instance = [tuple(("feature_{0}".format(i), float(instance.split(',')[i].strip()))) for i in + range(len(instance.split(',')))] + + return splitted_instance + + def parse_contents_instance(contents, filename): content_type, content_string = contents.split(',') decoded = base64.b64decode(content_string) try: if '.csv' in filename: data = decoded.decode('utf-8') - features_names, data = str(data).strip().split('\n')[:2] + features_names = str(data).strip().split('\n')[0] features_names = str(features_names).strip().split(',') - data = str(data).strip().split(',') - data = list(tuple([features_names[i], np.float32(data[i])]) for i in range(len(data))) + data = str(data).strip().split('\n')[1:] + data = list(map(lambda inst: split_instance_according_to_format(inst, features_names), data)) elif '.txt' in filename: data = decoded.decode('utf-8') - data = str(data).strip().split(',') - data = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in data])) + data = str(data).split('\n') + data = list(map(lambda inst: split_instance_according_to_format(inst), data)) elif '.json' in filename: data = decoded.decode('utf-8').strip() data = json.loads(data) - data = list(tuple(data.items())) + data = [tuple(data.items())] elif '.inst' in filename: data = decoded.decode('utf-8').strip() data = json.loads(data) - data = list(tuple(data.items())) + data = [tuple(data.items())] elif '.samples' in filename: decoded = decoded.decode('utf-8').strip() data = str(decoded).split('\n') + data = list(map(lambda inst: split_instance_according_to_format(inst), data)) except Exception as e: print(e) return html.Div([ -- GitLab