diff --git a/paradise/main.py b/paradise/main.py
index ee2b7a8eb0ed5787d0633783d01418487541fa34..90a007c80155b1a1f668becde12994db1047cb39 100755
--- a/paradise/main.py
+++ b/paradise/main.py
@@ -63,6 +63,11 @@ if __name__ == "__main__":
         help="Use Docker containers directly to run the algorithm. Allow to run algorithms without cloning repo",
         action="store_true"
     )
+    parser.add_argument(
+        "--merge",
+        help="Whether the method should merge the correlation matrices or not.",
+        action="store_true"
+    )
 
     # Load args
     args = parser.parse_args()
@@ -157,7 +162,7 @@ if __name__ == "__main__":
 
     if args.split and args.task in ["train", "all"]:
         splitter = BaseSplitter(f"{INPUT_DIR}/{config_name}")
-        splitter.split_data(method=args.method)
+        splitter.split_data(method=args.method, merge=args.merge)
 
 # =================================================================================================================
 # Train algorithm
diff --git a/paradise/split/base.py b/paradise/split/base.py
index 0f322987af4982d3f1f71aa7d289ebfea753f537..9facc06a670f3d1ed4de32df4236c5aa1cba3077 100644
--- a/paradise/split/base.py
+++ b/paradise/split/base.py
@@ -31,7 +31,7 @@ class BaseSplitter:
         self.output_path = f"{path}/splitting"
         os.makedirs(f"{path}/splitting", exist_ok=True)
 
-    def split_data(self, method="HDBSCAN"):
+    def split_data(self, method="HDBSCAN", merge=False):
         """
         This method will be in charge of splitting data into subsystems.
         """
@@ -49,7 +49,7 @@ class BaseSplitter:
             # cluters from its coefficient
             max_silhouette = 0
             best_clusters = None
-            x = self._compute_correlations(w_df)
+            x = self._compute_correlations(w_df, merge=merge)
 
             if "HDBSCAN" == method:
                 model = HDBSCAN(min_cluster_size=2, allow_single_cluster=True, n_jobs=-1)
@@ -83,7 +83,7 @@ class BaseSplitter:
             labels = np.bitwise_or.reduce(labels_df.drop(columns=drop).to_numpy(), axis=1, dtype=np.int32)
             pd.DataFrame(labels).to_csv(f"{self.data_path}/dataset_{i}_auto_split_labels.csv", index=False)
 
-    def _compute_correlations(self, data):
+    def _compute_correlations(self, data, merge=False):
         """
         Compute the vector of correlation coefficients for each of the variable of the dataset.
         """
@@ -106,11 +106,14 @@ class BaseSplitter:
 
             x.append(np.abs(correlation_matrix))
 
-        x = np.array(x)
-        x = np.mean(x, axis=0)
+        if merge:
+            x = np.array(x)
+            x = np.mean(x, axis=0)
 
-        sns.heatmap(x, annot=True, cmap="coolwarm")\
-            .get_figure()\
-            .savefig(f"{self.output_path}/dataset_final_matrix.png")
+            sns.heatmap(x, annot=True, cmap="coolwarm")\
+                .get_figure()\
+                .savefig(f"{self.output_path}/dataset_final_matrix.png")
+        else:
+            x = np.concatenate(x, axis=1)
 
         return x