Commit def20d90 authored by Leahcimali

Occidata

parent 83348f60
*.py~
results/*
.vscode/*
src/__pycache__/*
info.log
launch.json
backup_results/*
*.sh
tests/__pycache__/*
datasets/*
pub/*
data/*
*.tgz
*.pyc
src/__pycache__/*
\ No newline at end of file
#### Code for the paper: *Comparative Evaluation of Clustered Federated Learning Methods*
##### Submitted to 'The 2nd IEEE International Conference on Federated Learning Technologies and Applications (FLTA24), VALENCIA, SPAIN'
1. To reproduce the results in the paper, run `driver.py` with the parameters in `exp_configs.csv` (an example invocation is shown below)
2. Each experiment will output a `.csv` file with the resulting metrics
3. Histogram plots and a summary table of various experiments can be obtained by running `src/utils_results.py`
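A hedged example invocation: the flag names below come from the click options defined in `driver.py`; the `--centralized_epochs` and `--federated_rounds` flags are assumed to exist since they appear in `main_driver`'s signature, and the values mirror one row of `exp_configs.csv`.

```
python driver.py --exp_type global-federated --dataset cifar10 --nn_model convolutional \
    --heterogeneity_type concept-shift-on-features --num_clients 48 \
    --num_samples_by_label 100 --num_clusters 4 --centralized_epochs 20 \
    --federated_rounds 50 --seed 42
```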
import os
# Set the environment variable for deterministic CuBLAS behavior (gives reproducibility with CUDA)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
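# A minimal sketch (an assumption, not part of this commit): full run-to-run reproducibility
# on GPU typically also requires seeding every RNG and enabling deterministic kernels, e.g.
#   torch.manual_seed(seed); numpy.random.seed(seed); random.seed(seed)
#   torch.use_deterministic_algorithms(True)
# for which CUBLAS_WORKSPACE_CONFIG=":4096:8" is a prerequisite on CUDA.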
import click
@click.command()
@click.option('--exp_type', help="The experiment type to run")
@click.option('--heterogeneity_type', help="The data heterogeneity to test (or dataset)")
@click.option('--dataset')
@click.option('--nn_model', help="The training model to use ('linear' (default) or 'convolutional')")
@click.option('--num_clients', type=int)
@click.option('--num_samples_by_label', type=int)
@click.option('--num_clusters', type=int)
@@ -14,7 +18,6 @@ import click
@click.option('--seed', type=int)
def main_driver(exp_type, dataset, nn_model, heterogeneity_type, num_clients, num_samples_by_label, num_clusters, centralized_epochs, federated_rounds, seed):
from pathlib import Path
...
exp_type,dataset,nn_model,heterogeneity_type,num_clients,num_samples_by_label,num_clusters,centralized_epochs,federated_rounds,seed
pers-centralized,cifar10,convolutional,concept-shift-on-features,48,100,4,200,0,42
global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,20,50,42
global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,20,100,42
global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,20,150,42
global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,20,200,42
global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,50,50,42
global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,50,100,42
global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,50,150,42
global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,50,200,42
\ No newline at end of file
# Automatically generated by https://github.com/damnever/pigar.
imbalanced-learn==0.12.3
inputimeout==1.0.4
kiwisolver==1.4.5
matplotlib==3.9.0
numpy==1.26.4
opencv-python==4.10.0.84
pandas==2.2.2
scikit-learn==1.5.0
scipy==1.14.0
tensorflow==2.16.2
@@ -2,23 +2,26 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
def accuracy(outputs, labels):
_, preds = torch.max(outputs, dim=1)
return torch.tensor(torch.sum(preds == labels).item() / len(preds))
class ImageClassificationBase(nn.Module):
def training_step(self, batch, device):
images, labels = batch
images, labels = images.to(device), labels.to(device)
out = self(images)
loss = F.cross_entropy(out, labels)
return loss
def validation_step(self, batch, device):
images, labels = batch
images, labels = images.to(device), labels.to(device)
out = self(images)
loss = F.cross_entropy(out, labels)
acc = accuracy(out, labels)
return {'val_loss': loss.detach(), 'val_acc': acc}
@@ -32,60 +35,106 @@ class ImageClassificationBase(nn.Module):
def epoch_end(self, epoch, result):
print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
epoch, result['train_loss'], result['val_loss'], result['val_acc']))
class GenericLinearModel(ImageClassificationBase):
def __init__(self, in_size, n_channels):
super().__init__()
self.in_size = in_size
self.network = nn.Sequential(
nn.Linear(in_size * in_size, 200),
nn.Linear(200, 10)
)
def forward(self, xb):
xb = xb.view(-1, self.in_size * self.in_size)
return self.network(xb)
class GenericConvModel(ImageClassificationBase):
def __init__(self, in_size, n_channels):
super().__init__()
self.img_final_size = int(in_size / (2**3))
self.network = nn.Sequential(
nn.Conv2d(n_channels, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 64 x 16 x 16
nn.Dropout(0.25),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 128 x 8 x 8
nn.Dropout(0.25),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 256 x 4 x 4
nn.Dropout(0.25),
nn.Flatten(),
nn.Linear(256 * self.img_final_size * self.img_final_size, 1024),
nn.ReLU(),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, xb):
return self.network(xb)
\ No newline at end of file
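# Shape sketch (illustrative, for CIFAR-10-sized inputs of 3 x 32 x 32): the three
# MaxPool2d layers halve the spatial size three times, so img_final_size = 32 / 2**3 = 4
# and the flattened feature vector has 256 * 4 * 4 = 4096 entries, matching the first
# nn.Linear layer above.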
class CovNet(ImageClassificationBase):
def __init__(self, in_size, n_channels):
super().__init__()
self.img_final_size = int(in_size / (2**3))
self.network = nn.Sequential(
nn.Conv2d(n_channels, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 32 x 16 x 16
nn.Dropout(0.25),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 64 x 8 x 8
nn.Dropout(0.25),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 128 x 4 x 4
nn.Dropout(0.25),
nn.Flatten(),
nn.Linear(128 * self.img_final_size * self.img_final_size, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 10)
)
def forward(self, xb):
return self.network(xb)
@@ -3,35 +3,6 @@ from torch.utils.data import DataLoader
from numpy import ndarray
from typing import Tuple
def shuffle_list(list_samples : int, seed : int) -> list:
"""Function to shuffle the samples list
@@ -75,42 +46,28 @@ def create_label_dict(dataset : str, nn_model : str) -> dict:
import sys
import numpy as np
import torchvision
import torchvision.transforms as transforms
if dataset == "fashion-mnist": if dataset == "fashion-mnist":
fashion_mnist = torchvision.datasets.MNIST("datasets", download=True, transform=transform) fashion_mnist = torchvision.datasets.MNIST("datasets", download=True)
(x_data, y_data) = fashion_mnist.data, fashion_mnist.targets (x_data, y_data) = fashion_mnist.data, fashion_mnist.targets
if nn_model == "convolutional": if nn_model in ["convolutional","CovNet"]:
x_data = x_data.unsqueeze(1) x_data = x_data.unsqueeze(1)
elif dataset == 'mnist': elif dataset == 'mnist':
mnist = torchvision.datasets.MNIST("datasets", download=True) mnist = torchvision.datasets.MNIST("datasets", download=True)
(x_data, y_data) = mnist.data, mnist.targets (x_data, y_data) = mnist.data, mnist.targets
if nn_model == "convolutional":
x_data = x_data.unsqueeze(1)
elif dataset == "cifar10": elif dataset == "cifar10":
cifar10 = torchvision.datasets.CIFAR10("datasets", download=True, transform=transform) cifar10 = torchvision.datasets.CIFAR10("datasets", download=True)
(x_data, y_data) = cifar10.data, cifar10.targets (x_data, y_data) = cifar10.data, cifar10.targets
x_data = np.transpose(x_data, (0, 3, 1, 2)) #x_data = np.transpose(x_data, (0, 3, 1, 2))
elif dataset == 'kmnist': elif dataset == 'kmnist':
kmnist = torchvision.datasets.KMNIST("datasets", download=True, transform=transform) kmnist = torchvision.datasets.KMNIST("datasets", download=True)
(x_data, y_data) = kmnist.load_data() (x_data, y_data) = kmnist.load_data()
if nn_model == "convolutional":
x_data = x_data.unsqueeze(1)
else:
sys.exit("Unrecognized dataset. Please make sure you are using one of the following: ['mnist', 'fashion-mnist', 'kmnist', 'cifar10']")
@@ -194,10 +151,67 @@ def rotate_images(client: Client, rotation: int) -> None:
return
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset, Dataset
import torchvision.transforms as transforms
class AddGaussianNoise(object):
def __init__(self, mean=0., std=1.):
self.std = std
self.mean = mean
def __call__(self, tensor):
import torch
return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class AddRandomJitter(object):
def __init__(self, brightness=0.5, contrast=1, saturation=0.1, hue=0.5):
self.brightness = brightness
self.contrast = contrast
self.saturation = saturation
self.hue = hue
def __call__(self, tensor):
import torchvision.transforms as transforms
transform = transforms.ColorJitter(brightness = self.brightness, contrast= self.contrast,
saturation = self.saturation, hue = self.hue)
return transform(tensor)
class CustomDataset(Dataset):
def __init__(self, data, labels, transform=None):
# Ensure data is in (N, H, W, C) format
self.data = data # Assume data is in (N, H, W, C) format
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
image = self.data[idx] # Shape (H, W, C)
label = self.labels[idx]
# Convert image to tensor and permute to (C, H, W)
image = torch.tensor(image, dtype=torch.float) # Convert to tensor
image = image.permute(2, 0, 1) # Change to (C, H, W)
# Apply transformation if provided
if self.transform:
image = self.transform(image)
return image, label
def data_preparation(client: Client, row_exp: dict) -> None:
"""Saves DataLoaders of train, validation, and test data in the Client attributes
Arguments:
@@ -206,37 +220,52 @@ def data_preparation(client : Client, row_exp : dict) -> None:
"""
def to_device_tensor(data, device, data_dtype):
data = torch.tensor(data, dtype=data_dtype)
data = data.to(device)
return data
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset, Dataset
import torchvision.transforms as transforms
import numpy as np # Import NumPy for transpose operation
# Define data augmentation transforms
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
AddGaussianNoise(0., 1.),
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize if needed
])
# Transform for validation and test data (no augmentation, just normalization)
test_val_transform = transforms.Compose([
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize if needed
])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Split into train, validation, and test sets
x_data, x_test, y_data, y_test = train_test_split(client.data['x'], client.data['y'], test_size=0.3, random_state=row_exp['seed'], stratify=client.data['y'])
x_train, x_val, y_train, y_val = train_test_split(x_data, y_data, test_size=0.25, random_state=42)
# Create datasets with transformations
train_dataset = CustomDataset(x_train, y_train, transform=train_transform)
val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform)
test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform)
# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
# Store DataLoaders in the client object
setattr(client, 'data_loader', {'train': train_loader, 'val': validation_loader, 'test': test_loader})
setattr(client, 'train_test', {'x_train': x_train, 'x_val': x_val, 'x_test': x_test, 'y_train': y_train, 'y_val': y_val, 'y_test': y_test})
return
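# Usage sketch (assumption: a Client whose .data holds 'x' and 'y' arrays): after
# data_preparation(client, {'seed': 42}), next(iter(client.data_loader['train']))
# yields an (images, labels) pair with batch dimension 64, per the DataLoaders above.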
def get_dataset_heterogeneities(heterogeneity_type: str) -> dict:
@@ -278,7 +307,7 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
"""
from src.models import GenericConvModel, CovNet
from src.utils_fed import init_server_cluster
import torch
@@ -297,6 +326,9 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
elif row_exp['nn_model'] == "convolutional":
model_server = Server(GenericConvModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
elif row_exp['nn_model'] == "CovNet":
model_server = Server(CovNet(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
model_server.model.to(device)
@@ -594,7 +626,8 @@ def centralize_data(list_clients : list) -> Tuple[DataLoader, DataLoader]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))], axis=0)
y_train = np.concatenate([list_clients[id].train_test['y_train'] for id in range(len(list_clients))], axis=0)
x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
@@ -616,6 +649,58 @@ def centralize_data(list_clients : list) -> Tuple[DataLoader, DataLoader]:
return train_loader, val_loader, test_loader
def centralize_data(list_clients: list) -> Tuple[DataLoader, DataLoader, DataLoader]:
"""Centralize data of the federated learning setup for central model comparison
Arguments:
list_clients : The list of Client Objects
Returns:
Train, validation, and test torch DataLoaders with the data of all Clients
"""
from torchvision import transforms
import torch
from torch.utils.data import DataLoader,TensorDataset
import numpy as np
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
AddGaussianNoise(0., 1.),
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize if needed
])
# Transform for validation and test data (no augmentation, just normalization)
test_val_transform = transforms.Compose([
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize if needed
])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Concatenate training data from all clients
x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))], axis=0)
y_train = np.concatenate([list_clients[id].train_test['y_train'] for id in range(len(list_clients))], axis=0)
# Concatenate validation data from all clients
x_val = np.concatenate([list_clients[id].train_test['x_val'] for id in range(len(list_clients))], axis=0)
y_val = np.concatenate([list_clients[id].train_test['y_val'] for id in range(len(list_clients))], axis=0)
# Concatenate test data from all clients
x_test = np.concatenate([list_clients[id].train_test['x_test'] for id in range(len(list_clients))], axis=0)
y_test = np.concatenate([list_clients[id].train_test['y_test'] for id in range(len(list_clients))], axis=0)
# Create Custom Datasets
train_dataset = CustomDataset(x_train, y_train, transform=train_transform)
val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform)
test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform)
# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False) # Validation typically not shuffled
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) # Test data typically not shuffled
return train_loader, val_loader, test_loader
...
from src.fedclass import Server
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def send_server_model_to_client(list_clients : list, my_server : Server) -> None:
""" Function to copy the Server model to client attributes in a FL protocol
@@ -120,12 +123,12 @@ def model_weight_matrix(list_clients : list) -> pd.DataFrame:
model_dict = {client.id : client.model for client in list_clients}
shapes = [param.data.cpu().numpy().shape for param in next(iter(model_dict.values())).parameters()]
weight_matrix_np = np.empty((len(model_dict), sum(np.prod(shape) for shape in shapes)))
for idx, (_, model) in enumerate(model_dict.items()):
model_weights = np.concatenate([param.data.cpu().numpy().flatten() for param in model.parameters()])
weight_matrix_np[idx, :] = model_weights
weight_matrix = pd.DataFrame(weight_matrix_np, columns=[f'w_{i+1}' for i in range(weight_matrix_np.shape[1])])
@@ -196,7 +199,7 @@ def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict,
p_expert_opinion : Parameter to avoid completely random assignment if needed (default: 0)
"""
from src.models import GenericLinearModel, GenericConvModel, CovNet
import numpy as np
import copy
@@ -211,10 +214,8 @@ def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict,
p_rest = (1 - p_expert_opinion) / (row_exp['num_clusters'] - 1)
my_server.num_clusters = row_exp['num_clusters']
my_server.clusters_models = {cluster_id: GenericConvModel(in_size=imgs_params[0], n_channels=imgs_params[1]) for cluster_id in range(row_exp['num_clusters'])}
for client in list_clients:
probs = [p_rest if x != list_heterogeneities.index(client.heterogeneity_class) % row_exp['num_clusters']
@@ -241,7 +242,8 @@ def loss_calculation(model : nn.modules, train_loader : DataLoader) -> float:
import torch.nn as nn
criterion = nn.CrossEntropyLoss()
model.to(device)
model.eval()
total_loss = 0.0
...
id,cluster_id,heterogeneity_class,accuracy
0,0,none,82.0
1,0,erosion,56.666666666666664
2,0,dilatation,82.33333333333334
3,0,none,83.66666666666667
4,0,erosion,57.99999999999999
5,0,dilatation,83.0
id,cluster_id,heterogeneity_class,accuracy
0,,none,74.27777777777777
1,,erosion,74.27777777777777
2,,dilatation,74.27777777777777
3,,none,74.27777777777777
4,,erosion,74.27777777777777
5,,dilatation,74.27777777777777
id,cluster_id,heterogeneity_class,accuracy
0,,none,65.33333333333333
1,,erosion,39.5
2,,dilatation,79.0
3,,none,65.33333333333333
4,,erosion,39.5
5,,dilatation,79.0
id,cluster_id,heterogeneity_class,accuracy
0,2,none,85.66666666666667
1,1,erosion,66.66666666666666
2,0,dilatation,86.66666666666667
3,2,none,88.66666666666667
4,1,erosion,68.66666666666667
5,0,dilatation,89.0
import os
import pytest
from pathlib import Path
if os.getenv('_PYTEST_RAISE', "0") != "0":
@pytest.hookimpl(tryfirst=True)
def pytest_exception_interact(call):
raise call.excinfo.value
@pytest.hookimpl(tryfirst=True)
def pytest_internalerror(excinfo):
raise excinfo.value
def utils_extract_params(file_path: Path):
""" Creates a dictionary row_exp with the parameters for the experiment, given a well-formatted results file path
"""
with open(file_path, "r") as fp:
keys = ['exp_type', 'dataset', 'nn_model', 'heterogeneity_type', 'num_clients',
'num_samples_by_label', 'num_clusters', 'centralized_epochs',
'federated_rounds', 'seed']
parameters = file_path.stem.split('_')
row_exp = dict(
zip(keys,
parameters[:4] + [int(x) for x in parameters[4:]])
)
return row_exp
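# Illustrative example (stem taken from the reference files used in the tests below):
# "server_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42" parses to
# row_exp = {'exp_type': 'server', 'dataset': 'fashion-mnist', 'nn_model': 'linear',
#            'heterogeneity_type': 'features-distribution-skew', 'num_clients': 8,
#            'num_samples_by_label': 100, 'num_clusters': 3, 'centralized_epochs': 5,
#            'federated_rounds': 5, 'seed': 42}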
def test_run_cfl_benchmark_oracle():
from pathlib import Path
import numpy as np
import pandas as pd
from src.utils_data import setup_experiment
from src.utils_training import run_benchmark
file_path = Path("tests/refs/pers-centralized_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
row_exp = utils_extract_params(file_path)
model_server, list_clients = setup_experiment(row_exp)
df_results = run_benchmark(model_server, list_clients, row_exp)
assert all(np.isclose(df_results['accuracy'], pd.read_csv(file_path)['accuracy'], rtol=0.01))
def test_run_cfl_benchmark_fl():
from pathlib import Path
import numpy as np
import pandas as pd
from src.utils_data import setup_experiment
from src.utils_training import run_benchmark
file_path = Path("tests/refs/global-federated_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
row_exp = utils_extract_params(file_path)
model_server, list_clients = setup_experiment(row_exp)
df_results = run_benchmark(model_server, list_clients, row_exp)
assert all(np.isclose(df_results['accuracy'], pd.read_csv(file_path)['accuracy'], rtol=0.01))
def test_run_cfl_client_side():
from pathlib import Path
import numpy as np
import pandas as pd
from src.utils_data import setup_experiment
from src.utils_training import run_cfl_client_side
file_path = Path("tests/refs/client_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
row_exp = utils_extract_params(file_path)
model_server, list_clients = setup_experiment(row_exp)
df_results = run_cfl_client_side(model_server, list_clients, row_exp)
assert all(np.isclose(df_results['accuracy'], pd.read_csv(file_path)['accuracy'], rtol=0.01))
def test_run_cfl_server_side():
from pathlib import Path
import numpy as np
import pandas as pd
from src.utils_data import setup_experiment
from src.utils_training import run_cfl_server_side
file_path = Path("tests/refs/server_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
row_exp = utils_extract_params(file_path)
model_server, list_clients = setup_experiment(row_exp)
df_results = run_cfl_server_side(model_server, list_clients, row_exp)
assert all(np.isclose(df_results['accuracy'], pd.read_csv(file_path)['accuracy'], rtol=0.01))
if __name__ == "__main__":
test_run_cfl_client_side()
test_run_cfl_server_side()
test_run_cfl_benchmark_fl()
test_run_cfl_benchmark_oracle()
\ No newline at end of file