From 43a3caa32e72401adea3f9916a0cf551684d5859 Mon Sep 17 00:00:00 2001 From: Pierre LOTTE <pierrelotte.dev@gmail.com> Date: Wed, 9 Oct 2024 13:14:34 +0200 Subject: [PATCH] Add possibility to merge or not --- paradise/main.py | 7 ++++++- paradise/split/base.py | 19 +++++++++++-------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/paradise/main.py b/paradise/main.py index ee2b7a8..90a007c 100755 --- a/paradise/main.py +++ b/paradise/main.py @@ -63,6 +63,11 @@ if __name__ == "__main__": help="Use Docker containers directly to run the algorithm. Allow to run algorithms without cloning repo", action="store_true" ) + parser.add_argument( + "--merge", + help="Whether the method should merge the correlation matrices or not.", + action="store_true" + ) # Load args args = parser.parse_args() @@ -157,7 +162,7 @@ if __name__ == "__main__": if args.split and args.task in ["train", "all"]: splitter = BaseSplitter(f"{INPUT_DIR}/{config_name}") - splitter.split_data(method=args.method) + splitter.split_data(method=args.method, merge=args.merge) # ================================================================================================================= # Train algorithm diff --git a/paradise/split/base.py b/paradise/split/base.py index 0f32298..9facc06 100644 --- a/paradise/split/base.py +++ b/paradise/split/base.py @@ -31,7 +31,7 @@ class BaseSplitter: self.output_path = f"{path}/splitting" os.makedirs(f"{path}/splitting", exist_ok=True) - def split_data(self, method="HDBSCAN"): + def split_data(self, method="HDBSCAN", merge=False): """ This method will be in charge of splitting data into subsystems. """ @@ -49,7 +49,7 @@ class BaseSplitter: # cluters from its coefficient max_silhouette = 0 best_clusters = None - x = self._compute_correlations(w_df) + x = self._compute_correlations(w_df, merge=merge) if "HDBSCAN" == method: model = HDBSCAN(min_cluster_size=2, allow_single_cluster=True, n_jobs=-1) @@ -83,7 +83,7 @@ class BaseSplitter: labels = np.bitwise_or.reduce(labels_df.drop(columns=drop).to_numpy(), axis=1, dtype=np.int32) pd.DataFrame(labels).to_csv(f"{self.data_path}/dataset_{i}_auto_split_labels.csv", index=False) - def _compute_correlations(self, data): + def _compute_correlations(self, data, merge=False): """ Compute the vector of correlation coefficients for each of the variable of the dataset. """ @@ -106,11 +106,14 @@ class BaseSplitter: x.append(np.abs(correlation_matrix)) - x = np.array(x) - x = np.mean(x, axis=0) + if merge: + x = np.array(x) + x = np.mean(x, axis=0) - sns.heatmap(x, annot=True, cmap="coolwarm")\ - .get_figure()\ - .savefig(f"{self.output_path}/dataset_final_matrix.png") + sns.heatmap(x, annot=True, cmap="coolwarm")\ + .get_figure()\ + .savefig(f"{self.output_path}/dataset_final_matrix.png") + else: + x = np.concatenate(x, axis=1) return x -- GitLab