From c34194aa67f3c36b893058983e066a1786bc23c4 Mon Sep 17 00:00:00 2001
From: Leahcimali <michaelbenalipro@gmail.com>
Date: Tue, 29 Oct 2024 02:17:26 +0100
Subject: [PATCH] Correct mistakes of data transform for centralized

---
 src/utils_data.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/utils_data.py b/src/utils_data.py
index d0133dc..8d093bb 100644
--- a/src/utils_data.py
+++ b/src/utils_data.py
@@ -644,7 +644,7 @@ def centralize_data(list_clients: list, row_exp: dict) -> Tuple[DataLoader, Data
     import numpy as np 
     
 # Define data augmentation transforms
-    train_transform, test_val_transform = data_transformation(row_exp)
+    train_transform, val_transform, test_transform = data_transformation(row_exp)
 
     # Concatenate training data from all clients
     x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))], axis=0)
@@ -660,8 +660,8 @@ def centralize_data(list_clients: list, row_exp: dict) -> Tuple[DataLoader, Data
 
     # Create Custom Datasets
     train_dataset = CustomDataset(x_train, y_train, transform=train_transform)
-    val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform)
-    test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform)
+    val_dataset = CustomDataset(x_val, y_val, transform=val_transform)
+    test_dataset = CustomDataset(x_test, y_test, transform=test_transform)
 
     # Create DataLoaders
     train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
-- 
GitLab