Skip to content
Snippets Groups Projects
Commit f242fc6b authored by leahcimali's avatar leahcimali
Browse files

Correct errors with linear model data shape

parent 0512b9c0
Branches
No related tags found
No related merge requests found
......@@ -52,22 +52,21 @@ def create_label_dict(dataset : str, nn_model : str) -> dict:
if dataset == "fashion-mnist":
fashion_mnist = torchvision.datasets.MNIST("datasets", download=True)
(x_data, y_data) = fashion_mnist.data, fashion_mnist.targets
if nn_model in ["convolutional"]:
x_data = x_data.unsqueeze(1)
x_data = x_data.unsqueeze(3)
elif dataset == 'mnist':
mnist = torchvision.datasets.MNIST("datasets", download=True)
(x_data, y_data) = mnist.data, mnist.targets
x_data = x_data.unsqueeze(3)
elif dataset == "cifar10":
cifar10 = torchvision.datasets.CIFAR10("datasets", download=True)
(x_data, y_data) = cifar10.data, cifar10.targets
elif dataset == 'kmnist':
kmnist = torchvision.datasets.KMNIST("datasets", download=True)
(x_data, y_data) = kmnist.load_data()
x_data = kmnist.data # This gives you the images
x_data = x_data.unsqueeze(3)
y_data = kmnist.targets # This gives you the labels
else:
sys.exit("Unrecognized dataset. Please make sure you are using one of the following ['mnist', fashion-mnist', 'kmnist']")
......@@ -184,7 +183,7 @@ class AddRandomJitter(object):
saturation = self.saturation, hue = self.hue)
return transform(tensor)
class CustomDataset(Dataset):
class CifarDataset(Dataset):
def __init__(self, data, labels, transform=None):
# Ensure data is in (N, H, W, C) format
......@@ -229,18 +228,26 @@ def data_preparation(client: Client, row_exp: dict) -> None:
import numpy as np # Import NumPy for transpose operation
# Define data augmentation transforms
train_transform = transforms.Compose([
if row_exp['dataset'] == 'cifar10':
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(20), # Normalize if needed
transforms.RandomCrop(32, padding=4),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Transform for validation and test data (no augmentation, just normalization)
test_val_transform = transforms.Compose([
])
# Transform for validation and test data (no augmentation, just normalization)
test_val_transform = transforms.Compose([
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize if needed
])
else :
train_transform = transforms.Compose([
transforms.Normalize((0.5,), (0.5,)), # Normalize if needed
])
# Transform for validation and test data (no augmentation, just normalization)
test_val_transform = transforms.Compose([
transforms.Normalize((0.5,), (0.5,)), # Normalize if needed
])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
......@@ -250,9 +257,9 @@ def data_preparation(client: Client, row_exp: dict) -> None:
# Create datasets with transformations
train_dataset = CustomDataset(x_train, y_train, transform=train_transform)
val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform)
test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform)
train_dataset = CifarDataset(x_train, y_train, transform=train_transform)
val_dataset = CifarDataset(x_val, y_val, transform=test_val_transform)
test_dataset = CifarDataset(x_test, y_test, transform=test_val_transform)
# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
......@@ -306,7 +313,7 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
"""
from src.models import GenericConvModel
from src.models import GenericConvModel,GenericLinearModel
from src.utils_fed import init_server_cluster
import torch
......@@ -320,7 +327,7 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
if row_exp['nn_model'] == "linear":
model_server = Server(GenericConvModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
model_server = Server(GenericLinearModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
elif row_exp['nn_model'] == "convolutional":
......@@ -651,9 +658,9 @@ def centralize_data(list_clients: list) -> Tuple[DataLoader, DataLoader]:
y_test = np.concatenate([list_clients[id].train_test['y_test'] for id in range(len(list_clients))], axis=0)
# Create Custom Datasets
train_dataset = CustomDataset(x_train, y_train, transform=train_transform)
val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform)
test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform)
train_dataset = CifarDataset(x_train, y_train, transform=train_transform)
val_dataset = CifarDataset(x_val, y_val, transform=test_val_transform)
test_dataset = CifarDataset(x_test, y_test, transform=test_val_transform)
# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment