diff --git a/src/utils_data.py b/src/utils_data.py index 1a0058d33746eb509b0bd323878436af448acb3b..db1d33e18bc3ee207b5f7c1b1957a5b3306d2a0c 100644 --- a/src/utils_data.py +++ b/src/utils_data.py @@ -52,22 +52,21 @@ def create_label_dict(dataset : str, nn_model : str) -> dict: if dataset == "fashion-mnist": fashion_mnist = torchvision.datasets.MNIST("datasets", download=True) (x_data, y_data) = fashion_mnist.data, fashion_mnist.targets - - if nn_model in ["convolutional"]: - x_data = x_data.unsqueeze(1) - + x_data = x_data.unsqueeze(3) elif dataset == 'mnist': mnist = torchvision.datasets.MNIST("datasets", download=True) (x_data, y_data) = mnist.data, mnist.targets - + x_data = x_data.unsqueeze(3) elif dataset == "cifar10": cifar10 = torchvision.datasets.CIFAR10("datasets", download=True) (x_data, y_data) = cifar10.data, cifar10.targets - - + elif dataset == 'kmnist': kmnist = torchvision.datasets.KMNIST("datasets", download=True) - (x_data, y_data) = kmnist.load_data() + x_data = kmnist.data # This gives you the images + x_data = x_data.unsqueeze(3) + y_data = kmnist.targets # This gives you the labels + else: sys.exit("Unrecognized dataset. Please make sure you are using one of the following ['mnist', fashion-mnist', 'kmnist']") @@ -184,7 +183,7 @@ class AddRandomJitter(object): saturation = self.saturation, hue = self.hue) return transform(tensor) -class CustomDataset(Dataset): +class CifarDataset(Dataset): def __init__(self, data, labels, transform=None): # Ensure data is in (N, H, W, C) format @@ -229,18 +228,26 @@ def data_preparation(client: Client, row_exp: dict) -> None: import numpy as np # Import NumPy for transpose operation # Define data augmentation transforms - - train_transform = transforms.Compose([ + if row_exp['dataset'] == 'cifar10': + train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomRotation(20), # Normalize if needed transforms.RandomCrop(32, padding=4), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) - ]) - - # Transform for validation and test data (no augmentation, just normalization) - test_val_transform = transforms.Compose([ + ]) + # Transform for validation and test data (no augmentation, just normalization) + test_val_transform = transforms.Compose([ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize if needed ]) + else : + train_transform = transforms.Compose([ + transforms.Normalize((0.5,), (0.5,)), # Normalize if needed + ]) + + # Transform for validation and test data (no augmentation, just normalization) + test_val_transform = transforms.Compose([ + transforms.Normalize((0.5,), (0.5,)), # Normalize if needed + ]) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -250,9 +257,9 @@ def data_preparation(client: Client, row_exp: dict) -> None: # Create datasets with transformations - train_dataset = CustomDataset(x_train, y_train, transform=train_transform) - val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform) - test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform) + train_dataset = CifarDataset(x_train, y_train, transform=train_transform) + val_dataset = CifarDataset(x_val, y_val, transform=test_val_transform) + test_dataset = CifarDataset(x_test, y_test, transform=test_val_transform) # Create DataLoaders train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True) @@ -306,7 +313,7 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]: """ - from src.models import GenericConvModel + from src.models import GenericConvModel,GenericLinearModel from src.utils_fed import init_server_cluster import torch @@ -320,7 +327,7 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]: if row_exp['nn_model'] == "linear": - model_server = Server(GenericConvModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1])) + model_server = Server(GenericLinearModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1])) elif row_exp['nn_model'] == "convolutional": @@ -651,9 +658,9 @@ def centralize_data(list_clients: list) -> Tuple[DataLoader, DataLoader]: y_test = np.concatenate([list_clients[id].train_test['y_test'] for id in range(len(list_clients))], axis=0) # Create Custom Datasets - train_dataset = CustomDataset(x_train, y_train, transform=train_transform) - val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform) - test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform) + train_dataset = CifarDataset(x_train, y_train, transform=train_transform) + val_dataset = CifarDataset(x_val, y_val, transform=test_val_transform) + test_dataset = CifarDataset(x_test, y_test, transform=test_val_transform) # Create DataLoaders train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)