From 746520a8b41bf8415f599ee2f51b93770e1733c7 Mon Sep 17 00:00:00 2001 From: Leahcimali <michaelbenalipro@gmail.com> Date: Tue, 8 Oct 2024 11:07:31 +0200 Subject: [PATCH] Update data augmentation strategy --- src/utils_data.py | 52 +++++++---------------------------------------- 1 file changed, 7 insertions(+), 45 deletions(-) diff --git a/src/utils_data.py b/src/utils_data.py index d285737..8385860 100644 --- a/src/utils_data.py +++ b/src/utils_data.py @@ -209,8 +209,6 @@ class CustomDataset(Dataset): return image, label - - def data_preparation(client: Client, row_exp: dict) -> None: """Saves Dataloaders of train and test data in the Client attributes @@ -231,11 +229,12 @@ def data_preparation(client: Client, row_exp: dict) -> None: import numpy as np # Import NumPy for transpose operation # Define data augmentation transforms + train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), - AddGaussianNoise(0., 1.), - transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize if needed + transforms.RandomRotation(20), # Normalize if needed + transforms.RandomCrop(32, padding=4), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # Transform for validation and test data (no augmentation, just normalization) @@ -611,43 +610,6 @@ def swap_labels(labels : list, client : Client, heterogeneity_class : int) -> Cl return client -''' -def centralize_data(list_clients : list) -> Tuple[DataLoader, DataLoader]: - """Centralize data of the federated learning setup for central model comparison - - Arguments: - list_clients : The list of Client Objects - - Returns: - Train and test torch DataLoaders with data of all Clients - """ - import torch - from torch.utils.data import DataLoader,TensorDataset - import numpy as np - - - x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))],axis = 0) - y_train = np.concatenate([list_clients[id].train_test['y_train'] for id in range(len(list_clients))],axis = 0) - x_train_tensor = torch.tensor(x_train, dtype=torch.float32) - y_train_tensor = torch.tensor(y_train, dtype=torch.long) - - x_val = np.concatenate([list_clients[id].train_test['x_val'] for id in range(len(list_clients))],axis = 0) - y_val = np.concatenate([list_clients[id].train_test['y_val'] for id in range(len(list_clients))],axis = 0) - x_val_tensor = torch.tensor(x_val, dtype=torch.float32) - y_val_tensor = torch.tensor(y_val, dtype=torch.long) - - x_test = np.concatenate([list_clients[id].train_test['x_test'] for id in range(len(list_clients))],axis = 0) - y_test = np.concatenate([list_clients[id].train_test['y_test'] for id in range(len(list_clients))],axis = 0) - x_test_tensor = torch.tensor(x_test, dtype=torch.float32) - y_test_tensor = torch.tensor(y_test, dtype=torch.long) - - train_loader = DataLoader(TensorDataset(x_train_tensor, y_train_tensor), batch_size=128, shuffle=True) - val_loader = DataLoader(TensorDataset(x_val_tensor, y_val_tensor), batch_size=128, shuffle=False) - test_loader = DataLoader(TensorDataset(x_test_tensor, y_test_tensor), batch_size=128, shuffle=False) - - return train_loader, val_loader, test_loader -''' - def centralize_data(list_clients: list) -> Tuple[DataLoader, DataLoader]: """Centralize data of the federated learning setup for central model comparison @@ -665,9 +627,9 @@ def centralize_data(list_clients: list) -> Tuple[DataLoader, DataLoader]: train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), - AddGaussianNoise(0., 1.), - transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize if needed + transforms.RandomRotation(20), # Normalize if needed + transforms.RandomCrop(32, padding=4), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # Transform for validation and test data (no augmentation, just normalization) -- GitLab