From c34194aa67f3c36b893058983e066a1786bc23c4 Mon Sep 17 00:00:00 2001 From: Leahcimali <michaelbenalipro@gmail.com> Date: Tue, 29 Oct 2024 02:17:26 +0100 Subject: [PATCH] Correct mistakes of data transform for centralized --- src/utils_data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/utils_data.py b/src/utils_data.py index d0133dc..8d093bb 100644 --- a/src/utils_data.py +++ b/src/utils_data.py @@ -644,7 +644,7 @@ def centralize_data(list_clients: list, row_exp: dict) -> Tuple[DataLoader, Data import numpy as np # Define data augmentation transforms - train_transform, test_val_transform = data_transformation(row_exp) + train_transform, val_transform, test_transform = data_transformation(row_exp) # Concatenate training data from all clients x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))], axis=0) @@ -660,8 +660,8 @@ def centralize_data(list_clients: list, row_exp: dict) -> Tuple[DataLoader, Data # Create Custom Datasets train_dataset = CustomDataset(x_train, y_train, transform=train_transform) - val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform) - test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform) + val_dataset = CustomDataset(x_val, y_val, transform=val_transform) + test_dataset = CustomDataset(x_test, y_test, transform=test_transform) # Create DataLoaders train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True) -- GitLab