From 746520a8b41bf8415f599ee2f51b93770e1733c7 Mon Sep 17 00:00:00 2001
From: Leahcimali <michaelbenalipro@gmail.com>
Date: Tue, 8 Oct 2024 11:07:31 +0200
Subject: [PATCH] Update data augmentation strategy

---
 src/utils_data.py | 52 +++++++----------------------------------------
 1 file changed, 7 insertions(+), 45 deletions(-)

diff --git a/src/utils_data.py b/src/utils_data.py
index d285737..8385860 100644
--- a/src/utils_data.py
+++ b/src/utils_data.py
@@ -209,8 +209,6 @@ class CustomDataset(Dataset):
 
         return image, label
 
-
-
 def data_preparation(client: Client, row_exp: dict) -> None:
     """Saves Dataloaders of train and test data in the Client attributes 
     
@@ -231,11 +229,12 @@ def data_preparation(client: Client, row_exp: dict) -> None:
     import numpy as np  # Import NumPy for transpose operation
     
     # Define data augmentation transforms
+
     train_transform = transforms.Compose([
         transforms.RandomHorizontalFlip(),
-        AddGaussianNoise(0., 1.),
-        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
-        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize if needed
+        transforms.RandomRotation(20),  # Rotate by a random angle in [-20, 20] degrees
+        transforms.RandomCrop(32, padding=4),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])
     
     # Transform for validation and test data (no augmentation, just normalization)
@@ -611,43 +610,6 @@ def swap_labels(labels : list, client : Client, heterogeneity_class : int) -> Cl
     return client
 
 
-'''
-def centralize_data(list_clients : list) -> Tuple[DataLoader, DataLoader]:
-    """Centralize data of the federated learning setup for central model comparison
-
-    Arguments:
-        list_clients : The list of Client Objects
-
-    Returns:
-        Train and test torch DataLoaders with data of all Clients
-    """
-    import torch 
-    from torch.utils.data import DataLoader,TensorDataset
-    import numpy as np 
-    
-    
-    x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))],axis = 0)
-    y_train = np.concatenate([list_clients[id].train_test['y_train'] for id in range(len(list_clients))],axis = 0)
-    x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
-    y_train_tensor = torch.tensor(y_train, dtype=torch.long)
-
-    x_val = np.concatenate([list_clients[id].train_test['x_val'] for id in range(len(list_clients))],axis = 0)
-    y_val = np.concatenate([list_clients[id].train_test['y_val'] for id in range(len(list_clients))],axis = 0)
-    x_val_tensor = torch.tensor(x_val, dtype=torch.float32)
-    y_val_tensor = torch.tensor(y_val, dtype=torch.long)
-
-    x_test = np.concatenate([list_clients[id].train_test['x_test'] for id in range(len(list_clients))],axis = 0)
-    y_test = np.concatenate([list_clients[id].train_test['y_test'] for id in range(len(list_clients))],axis = 0)
-    x_test_tensor = torch.tensor(x_test, dtype=torch.float32)
-    y_test_tensor = torch.tensor(y_test, dtype=torch.long)
-    
-    train_loader = DataLoader(TensorDataset(x_train_tensor, y_train_tensor), batch_size=128, shuffle=True)
-    val_loader = DataLoader(TensorDataset(x_val_tensor, y_val_tensor), batch_size=128, shuffle=False)
-    test_loader = DataLoader(TensorDataset(x_test_tensor, y_test_tensor), batch_size=128, shuffle=False)
-    
-    return train_loader, val_loader, test_loader
-'''
-
 def centralize_data(list_clients: list) -> Tuple[DataLoader, DataLoader]:
     """Centralize data of the federated learning setup for central model comparison
 
@@ -665,9 +627,9 @@ def centralize_data(list_clients: list) -> Tuple[DataLoader, DataLoader]:
 
     train_transform = transforms.Compose([
         transforms.RandomHorizontalFlip(),
-        AddGaussianNoise(0., 1.),
-        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
-        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize if needed
+        transforms.RandomRotation(20),  # Rotate by a random angle in [-20, 20] degrees
+        transforms.RandomCrop(32, padding=4),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])
     
     # Transform for validation and test data (no augmentation, just normalization)
-- 
GitLab