Skip to content
Snippets Groups Projects
Commit c3433fa9 authored by Caroline DE POURTALES
Browse files

cleaning instances

parent a0b2d351
No related branches found
No related tags found
No related merge requests found
......@@ -4,10 +4,7 @@ from dash import html
from pages.application.DecisionTree.utils.data import Data
from pages.application.DecisionTree.utils.dtree import DecisionTree
from pages.application.DecisionTree.utils.dtviz import (visualize,
visualize_expl,
visualize_instance,
visualize_contrastive_expl)
from pages.application.DecisionTree.utils.dtviz import *
from pages.application.DecisionTree.utils.upload_tree import UploadedDecisionTree
......@@ -116,19 +113,14 @@ class DecisionTreeComponent:
def update_with_explicability(self, instance, enum, xtype, solver):
instance = instance[0]
instance_translated = self.translate_instance(instance)
self.explanation = []
list_explanations_path = []
list_contrastive_explanations_path = []
explanation = self.dt.explain(instance_translated, enum=enum, xtype=xtype, solver=solver)
dot_source = visualize_instance(self.dt, instance_translated)
self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz(
dot_source=dot_source, style={"width": "50%",
"height": "80%",
"background-color": "transparent"}
)])
# Creating a clean and nice text component
# instance plotting
self.explanation.append(html.H4("Instance : \n"))
......@@ -147,6 +139,14 @@ class DecisionTreeComponent:
list_explanations_path = explanation["List of path explanation(s)"]
list_contrastive_explanations_path = explanation["List of path contrastive explanation(s)"]
# Create graph
dot_source = visualize_instance(self.dt, instance_translated)
self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz(
dot_source=dot_source, style={"width": "50%",
"height": "80%",
"background-color": "transparent"}
)])
return list_explanations_path, list_contrastive_explanations_path
def draw_explanation(self, instance, expl):
......
......@@ -32,13 +32,13 @@ def create_legend(G):
legend.add_edge(edge)
edge = pydot.Edge("e", "f")
edge.obj_dict['attributes']["label"] = "contrastive explanation"
edge.obj_dict['attributes']["label"] = "contrastive \n explanation"
edge.obj_dict['attributes']["color"] = "red"
edge.obj_dict['attributes']["style"] = "dashed"
legend.add_edge(edge)
edge = pydot.Edge("c", "d")
edge.obj_dict['attributes']["label"] = "instance with explanation"
edge.obj_dict['attributes']["label"] = "instance with \n explanation"
edge.obj_dict['attributes']["color"] = "blue"
edge.obj_dict['attributes']["style"] = "dashed"
legend.add_edge(edge)
......
......@@ -11,16 +11,14 @@ class RandomForestComponent:
# Conversion model
self.data = Dataset(info)
self.data.mapping_features()
if info is not None and 'csv' in type_info:
self.random_forest = XRF(model, self.data.feature_names, self.data.target_name)
# encoding here so not in the explanation
self.tree_to_plot = 0
dot_source = tree.export_graphviz(self.random_forest.cls.estimators()[self.tree_to_plot],
feature_names=self.data.feature_names, class_names=self.data.class_names,
feature_names=self.data.feature_names, class_names=list(map(lambda cl : str(cl), self.data.target_name)),
impurity=False, filled=False, rounded=True)
self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz(
dot_source=dot_source, style={"width": "50%",
......@@ -29,41 +27,38 @@ class RandomForestComponent:
)])
self.explanation = []
def update_with_explicability(self, instance, enum_feats=None, xtype=None, solver=None):
instance = instance[0]
if "=" in instance:
splitted_instance = [float(v.split('=')[1].strip()) for v in instance.split(',')]
else:
splitted_instance = [float(v.strip()) for v in instance.split(',')]
def update_with_explicability(self, instances, enum_feats=None, xtype=None, solver=None):
# Call instance
instances = [list(map(lambda feature: feature[1], instance)) for instance in instances]
self.explanation = []
self.explanation.append(html.H5("Instance"))
self.explanation.append(html.Hr())
self.explanation.append(
html.P(str([tuple((self.data.feature_names[i], str(splitted_instance[i]))) for i in
range(len(splitted_instance) - 1)]) + " THEN " + str(tuple((self.data.target_name, str(splitted_instance[-1]))))))
self.explanation.append(html.Hr())
# Call explanation
explanation_result = None
if isinstance(self.random_forest, XRF):
explanation_result = self.random_forest.explain(splitted_instance)
list_explanations_path = []
# Creating a clean and nice text component
for k in explanation_result.keys():
self.explanation.append(html.H5(k))
for instance in instances:
self.explanation.append(html.H4("Sample : "))
# Call instance
self.explanation.append(html.H5("Instance"))
self.explanation.append(html.Hr())
self.explanation.append(html.P(explanation_result[k]))
self.explanation.append(
html.P(str([tuple((self.data.feature_names[i], str(instance[i]))) for i in
range(len(instance) - 1)]) + " THEN " + str(
tuple((self.data.target_name, str(instance[-1]))))))
self.explanation.append(html.Hr())
return list_explanations_path, []
# Call explanation
explanation_result = None
if isinstance(self.random_forest, XRF):
explanation_result = self.random_forest.explain(instance)
# Creating a clean and nice text component
for k in explanation_result.keys():
self.explanation.append(html.H5(k))
self.explanation.append(html.Hr())
self.explanation.append(html.P(explanation_result[k]))
self.explanation.append(html.Hr())
return [], []
def update_plotted_tree(self, tree_to_plot):
self.tree_to_plot = tree_to_plot
dot_source = tree.export_graphviz(self.random_forest.cls.estimators()[self.tree_to_plot],
feature_names=self.data.feature_names, class_names=self.data.class_names,
feature_names=self.data.feature_names, class_names=list(map(lambda cl : str(cl), self.data.target_name)),
impurity=False, filled=False, rounded=True)
self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz(
dot_source=dot_source, style={"width": "50%",
......
......@@ -43,10 +43,6 @@ class Dataset(Data):
le.fit(samples[:, -1])
samples[:, -1] = le.transform(samples[:, -1])
self.class_names = le.classes_
print(le.classes_)
print(samples[1:4, :])
else :
self.class_names = np.unique(samples[:, -1])
samples = np.asarray(samples, dtype=np.float32)
self.X = samples[:, 0: self.nb_features]
......
......@@ -219,7 +219,7 @@ class View:
self.tree_to_plot = html.Div(id="choosing_tree", hidden=True,
children=[html.H5("Choose a tree to plot: "),
html.Div(children=[dcc.Slider(0, 100, 1,
html.Div(children=[dcc.Slider(0, 50, 1,
value=0,
id='choice_tree')])])
......
......@@ -10,6 +10,7 @@ from pages.application.RandomForest.utils import xrf
from pages.application.RandomForest.utils.xrf import *
sys.modules['xrf'] = xrf
def parse_contents_graph(contents, filename):
content_type, content_string = contents.split(',')
decoded = base64.b64decode(content_string)
......@@ -17,7 +18,7 @@ def parse_contents_graph(contents, filename):
if '.pkl' in filename:
try:
data = joblib.load(io.BytesIO(decoded))
except :
except:
data = pickle.load(io.BytesIO(decoded))
elif '.txt' in filename:
data = decoded.decode('utf-8').strip()
......@@ -47,31 +48,47 @@ def parse_contents_data(contents, filename):
return data
def split_instance_according_to_format(instance, features_names=None):
    """Parse one textual instance into a list of ``(feature_name, value)`` tuples.

    Two layouts are recognised:

    * ``name=value`` pairs separated by commas, selected when the text
      contains an ``=``; the names are read from the text itself.
    * bare comma-separated values; names are taken positionally from
      ``features_names`` when it is non-empty, otherwise generated as
      ``feature_<index>``.

    Args:
        instance: Text describing a single instance.
        features_names: Optional sequence of names matched to the values
            by position.

    Returns:
        A list of ``(name, float_value)`` tuples, in input order.

    Raises:
        ValueError: If a value cannot be converted to ``float``.
        IndexError: If a token has no ``=`` in the ``name=value`` layout, or
            if ``features_names`` holds fewer names than there are values.
    """
    named_layout = "=" in instance
    # Split the line a single time instead of once per extracted value.
    tokens = instance.split(',')

    if named_layout:
        # Split each token once; index 0 is the name, index 1 the value.
        pairs = [token.split('=') for token in tokens]
        return [(pair[0].strip(), float(pair[1].strip())) for pair in pairs]

    if features_names:
        # Index (rather than zip) so a too-short name list still fails loudly.
        return [(features_names[position], float(token.strip()))
                for position, token in enumerate(tokens)]

    return [("feature_{0}".format(position), float(token.strip()))
            for position, token in enumerate(tokens)]
def parse_contents_instance(contents, filename):
content_type, content_string = contents.split(',')
decoded = base64.b64decode(content_string)
try:
if '.csv' in filename:
data = decoded.decode('utf-8')
features_names, data = str(data).strip().split('\n')[:2]
features_names = str(data).strip().split('\n')[0]
features_names = str(features_names).strip().split(',')
data = str(data).strip().split(',')
data = list(tuple([features_names[i], np.float32(data[i])]) for i in range(len(data)))
data = str(data).strip().split('\n')[1:]
data = list(map(lambda inst: split_instance_according_to_format(inst, features_names), data))
elif '.txt' in filename:
data = decoded.decode('utf-8')
data = str(data).strip().split(',')
data = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in data]))
data = str(data).split('\n')
data = list(map(lambda inst: split_instance_according_to_format(inst), data))
elif '.json' in filename:
data = decoded.decode('utf-8').strip()
data = json.loads(data)
data = list(tuple(data.items()))
data = [tuple(data.items())]
elif '.inst' in filename:
data = decoded.decode('utf-8').strip()
data = json.loads(data)
data = list(tuple(data.items()))
data = [tuple(data.items())]
elif '.samples' in filename:
decoded = decoded.decode('utf-8').strip()
data = str(decoded).split('\n')
data = list(map(lambda inst: split_instance_according_to_format(inst), data))
except Exception as e:
print(e)
return html.Div([
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment