diff --git a/src/utils_data.py b/src/utils_data.py
index d0133dc0094c9067d127ad9f31b7d3c021b964c0..8d093bb1704b685f76a886605c22cf0cf5251cd6 100644
--- a/src/utils_data.py
+++ b/src/utils_data.py
@@ -644,7 +644,7 @@ def centralize_data(list_clients: list, row_exp: dict) -> Tuple[DataLoader, Data
     import numpy as np

     # Define data augmentation transforms
-    train_transform, test_val_transform = data_transformation(row_exp)
+    train_transform, val_transform, test_transform = data_transformation(row_exp)

     # Concatenate training data from all clients
     x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))], axis=0)
@@ -660,8 +660,8 @@ def centralize_data(list_clients: list, row_exp: dict) -> Tuple[DataLoader, Data

     # Create Custom Datasets
     train_dataset = CustomDataset(x_train, y_train, transform=train_transform)
-    val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform)
-    test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform)
+    val_dataset = CustomDataset(x_val, y_val, transform=val_transform)
+    test_dataset = CustomDataset(x_test, y_test, transform=test_transform)

     # Create DataLoaders
     train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)