Skip to content
Snippets Groups Projects
Commit c34194aa authored by Leahcimali's avatar Leahcimali
Browse files

Correct mistakes of data transform for centralized

parent 17c83ab4
Branches
No related tags found
No related merge requests found
......@@ -644,7 +644,7 @@ def centralize_data(list_clients: list, row_exp: dict) -> Tuple[DataLoader, Data
import numpy as np
# Define data augmentation transforms
train_transform, test_val_transform = data_transformation(row_exp)
train_transform, val_transform, test_transform = data_transformation(row_exp)
# Concatenate training data from all clients
x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))], axis=0)
......@@ -660,8 +660,8 @@ def centralize_data(list_clients: list, row_exp: dict) -> Tuple[DataLoader, Data
# Create Custom Datasets
train_dataset = CustomDataset(x_train, y_train, transform=train_transform)
val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform)
test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform)
val_dataset = CustomDataset(x_val, y_val, transform=val_transform)
test_dataset = CustomDataset(x_test, y_test, transform=test_transform)
# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment