diff --git a/src/utils_data.py b/src/utils_data.py index d28573784f369aa7cadac94503ef193f085f4076..838586092434938865f6c91de08b3b3d8449d6b5 100644 --- a/src/utils_data.py +++ b/src/utils_data.py @@ -209,8 +209,6 @@ class CustomDataset(Dataset): return image, label - - def data_preparation(client: Client, row_exp: dict) -> None: """Saves Dataloaders of train and test data in the Client attributes @@ -231,11 +229,12 @@ def data_preparation(client: Client, row_exp: dict) -> None: import numpy as np # Import NumPy for transpose operation # Define data augmentation transforms + train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), - AddGaussianNoise(0., 1.), - transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize if needed + transforms.RandomRotation(20), # Normalize if needed + transforms.RandomCrop(32, padding=4), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # Transform for validation and test data (no augmentation, just normalization) @@ -611,43 +610,6 @@ def swap_labels(labels : list, client : Client, heterogeneity_class : int) -> Cl return client -''' -def centralize_data(list_clients : list) -> Tuple[DataLoader, DataLoader]: - """Centralize data of the federated learning setup for central model comparison - - Arguments: - list_clients : The list of Client Objects - - Returns: - Train and test torch DataLoaders with data of all Clients - """ - import torch - from torch.utils.data import DataLoader,TensorDataset - import numpy as np - - - x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))],axis = 0) - y_train = np.concatenate([list_clients[id].train_test['y_train'] for id in range(len(list_clients))],axis = 0) - x_train_tensor = torch.tensor(x_train, dtype=torch.float32) - y_train_tensor = torch.tensor(y_train, dtype=torch.long) - - x_val = np.concatenate([list_clients[id].train_test['x_val'] for id in range(len(list_clients))],axis = 0) - y_val = np.concatenate([list_clients[id].train_test['y_val'] for id in range(len(list_clients))],axis = 0) - x_val_tensor = torch.tensor(x_val, dtype=torch.float32) - y_val_tensor = torch.tensor(y_val, dtype=torch.long) - - x_test = np.concatenate([list_clients[id].train_test['x_test'] for id in range(len(list_clients))],axis = 0) - y_test = np.concatenate([list_clients[id].train_test['y_test'] for id in range(len(list_clients))],axis = 0) - x_test_tensor = torch.tensor(x_test, dtype=torch.float32) - y_test_tensor = torch.tensor(y_test, dtype=torch.long) - - train_loader = DataLoader(TensorDataset(x_train_tensor, y_train_tensor), batch_size=128, shuffle=True) - val_loader = DataLoader(TensorDataset(x_val_tensor, y_val_tensor), batch_size=128, shuffle=False) - test_loader = DataLoader(TensorDataset(x_test_tensor, y_test_tensor), batch_size=128, shuffle=False) - - return train_loader, val_loader, test_loader -''' - def centralize_data(list_clients: list) -> Tuple[DataLoader, DataLoader]: """Centralize data of the federated learning setup for central model comparison @@ -665,9 +627,9 @@ def centralize_data(list_clients: list) -> Tuple[DataLoader, DataLoader]: train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), - AddGaussianNoise(0., 1.), - transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize if needed + transforms.RandomRotation(20), # Normalize if needed + transforms.RandomCrop(32, padding=4), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # Transform for validation and test data (no augmentation, just normalization)