diff --git a/src/utils_data.py b/src/utils_data.py
index d0133dc0094c9067d127ad9f31b7d3c021b964c0..8d093bb1704b685f76a886605c22cf0cf5251cd6 100644
--- a/src/utils_data.py
+++ b/src/utils_data.py
@@ -644,7 +644,7 @@ def centralize_data(list_clients: list, row_exp: dict) -> Tuple[DataLoader, Data
     import numpy as np 
     
 # Define data augmentation transforms
-    train_transform, test_val_transform = data_transformation(row_exp)
+    train_transform, val_transform, test_transform = data_transformation(row_exp)
 
     # Concatenate training data from all clients
     x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))], axis=0)
@@ -660,8 +660,8 @@ def centralize_data(list_clients: list, row_exp: dict) -> Tuple[DataLoader, Data
 
     # Create Custom Datasets
     train_dataset = CustomDataset(x_train, y_train, transform=train_transform)
-    val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform)
-    test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform)
+    val_dataset = CustomDataset(x_val, y_val, transform=val_transform)
+    test_dataset = CustomDataset(x_test, y_test, transform=test_transform)
 
     # Create DataLoaders
     train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)