Commit 410202b8 authored by Omar El Rifai

Allow model selection between linear and convolutional

parent 98f09846
@@ -5,6 +5,7 @@ import click
 @click.option('--exp_type', help="The experiment type to run")
 @click.option('--heterogeneity_type', help="The data heterogeneity to test (or dataset)")
 @click.option('--dataset')
+@click.option('--nn_model', help="The training model to use ('linear' (default) or 'convolutional')")
 @click.option('--num_clients', type=int)
 @click.option('--num_samples_by_label', type=int)
 @click.option('--num_clusters', type=int)
@@ -14,14 +15,14 @@ import click
-def main_driver(exp_type, dataset, heterogeneity_type, num_clients, num_samples_by_label, num_clusters, centralized_epochs, federated_rounds, seed):
+def main_driver(exp_type, dataset, nn_model, heterogeneity_type, num_clients, num_samples_by_label, num_clusters, centralized_epochs, federated_rounds, seed):
 
     from pathlib import Path
     import pandas as pd
     from src.utils_data import setup_experiment, get_uid
 
-    row_exp = pd.Series({"exp_type": exp_type, "dataset": dataset, "heterogeneity_type": heterogeneity_type, "num_clients": num_clients,
+    row_exp = pd.Series({"exp_type": exp_type, "dataset": dataset, "nn_model": nn_model, "heterogeneity_type": heterogeneity_type, "num_clients": num_clients,
                          "num_samples_by_label": num_samples_by_label, "num_clusters": num_clusters, "centralized_epochs": centralized_epochs,
                          "federated_rounds": federated_rounds, "seed": seed})
...
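For context, a minimal sketch of the row_exp series that main_driver now builds with the new nn_model field included; the parameter values below are only illustrative, not taken from the repository.

import pandas as pd

# Hypothetical values; only the keys mirror the main_driver signature above.
row_exp = pd.Series({"exp_type": "client", "dataset": "fashion-mnist", "nn_model": "convolutional",
                     "heterogeneity_type": "features-distribution-skew", "num_clients": 8,
                     "num_samples_by_label": 100, "num_clusters": 3, "centralized_epochs": 5,
                     "federated_rounds": 5, "seed": 42})
print(row_exp["nn_model"])  # -> "convolutional"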
@@ -3,10 +3,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 
-class SimpleLinear2(nn.Module):
+class SimpleLinear(nn.Module):
     """ Fully connected neural network with a single hidden layer of default size 200 and ReLU activations"""
 
-    def __init__(self, h1=200):
+    def __init__(self, in_size, n_channels):
         """ Initialization function
 
         Arguments:
@@ -14,8 +14,9 @@ class SimpleLinear2(nn.Module):
             Desired size of the hidden layer
         """
         super().__init__()
-        self.fc1 = nn.Linear(28*28, h1)
-        self.fc2 = nn.Linear(h1, 10)
+        self.fc1 = nn.Linear(in_size*in_size, 200)
+        self.fc2 = nn.Linear(200, 10)
+        self.in_size = in_size
 
     def forward(self, x: torch.Tensor):
@@ -23,19 +24,19 @@ class SimpleLinear2(nn.Module):
         Arguments:
             x : torch.Tensor
-                input image of size 28 x 28
+                input image of size in_size x in_size
 
         Returns:
             log_softmax probabilities of the output layer
         """
-        x = x.view(-1, 28 * 28)
+        x = x.view(-1, self.in_size * self.in_size)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
 
-class SimpleLinear(nn.Module):
+class SimpleConv(nn.Module):
     """ Convolutional neural network with 3 convolutional layers and one fully connected layer
     """
@@ -43,7 +44,7 @@ class SimpleLinear(nn.Module):
     def __init__(self, in_size, n_channels):
         """ Initialization function
         """
-        super(SimpleLinear, self).__init__()
+        super(SimpleConv, self).__init__()
         self.conv1 = nn.Conv2d(n_channels, 16, 3, padding=1)
         self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
...
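A minimal usage sketch of the two renamed models, assuming the src.models import path used elsewhere in this commit; only SimpleLinear's forward pass is visible in the diff, so the convolutional model is merely instantiated here.

import torch
from src.models import SimpleLinear, SimpleConv

# The linear model flattens each image to in_size*in_size features, so it accepts
# (batch, in_size, in_size) single-channel inputs; n_channels is accepted by the
# constructor but not used by the layers shown above.
linear = SimpleLinear(in_size=28, n_channels=1)
print(linear(torch.zeros(4, 28, 28)).shape)   # torch.Size([4, 10]) log-softmax scores

# The convolutional model starts with nn.Conv2d(n_channels, 16, 3, padding=1),
# so it expects (batch, n_channels, in_size, in_size) inputs.
conv = SimpleConv(in_size=28, n_channels=1)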
@@ -28,38 +28,39 @@ def shuffle_list(list_samples : int, seed : int) -> list:
 
     return shuffled_list
 
-def create_label_dict(dataset : dict) -> dict:
+def create_label_dict(dataset : str, nn_model : str) -> dict:
     """ Create a dictionary of dataset samples
 
     Arguments:
-        dataset : The name of the dataset to use ('fashion-mnist', 'mnist', or 'kmnist')
+        dataset : The name of the dataset to use (e.g. 'fashion-mnist', 'mnist', or 'cifar10')
+        nn_model : The training model type ('linear' or 'convolutional')
 
     Returns:
-        A dictionary of data of the form {'x': [], 'y': []}
+        label_dict : A dictionary of data of the form {'x': [], 'y': []}
 
     Raises:
-        Error if the dataset name is unrecognized
+        Error : if the dataset name is unrecognized
     """
 
     import sys
     import numpy as np
     import torchvision
-    from tensorflow.keras.datasets import mnist, fashion_mnist
     from extra_keras_datasets import kmnist
-    #import torchvision
 
     if dataset == "fashion-mnist":
         fashion_mnist = torchvision.datasets.MNIST("datasets", download=True)
         (x_train, y_train) = fashion_mnist.data, fashion_mnist.targets
+        if nn_model == "convolutional":
+            x_train = x_train.unsqueeze(1)
 
     elif dataset == 'mnist':
         mnist = torchvision.datasets.MNIST("datasets", download=True)
         (x_train, y_train) = mnist.data, mnist.targets
-        x_train = x_train.unsqueeze(1)
+        if nn_model == "convolutional":
+            x_train = x_train.unsqueeze(1)
 
     elif dataset == "cifar10":
         cifar10 = torchvision.datasets.CIFAR10("datasets", download=True)
@@ -68,6 +69,9 @@ def create_label_dict(dataset : dict) -> dict:
     elif dataset == 'kmnist':
         (x_train, y_train), _ = kmnist.load_data()
+        if nn_model == "convolutional":
+            x_train = x_train.unsqueeze(1)
+
     else:
         sys.exit("Unrecognized dataset. Please make sure you are using one of the following: ['mnist', 'fashion-mnist', 'kmnist', 'cifar10']")
@@ -83,14 +87,15 @@ def create_label_dict(dataset : dict) -> dict:
 
     return label_dict
 
-def get_clients_data(num_clients : int, num_samples_by_label : int, dataset : dict, seed : int) -> dict:
+def get_clients_data(num_clients : int, num_samples_by_label : int, dataset : str, nn_model : str) -> dict:
     """ Distribute a dataset evenly across num_clients clients. Works with datasets with 10 labels
 
     Arguments:
         num_clients : Number of clients of interest
         num_samples_by_label : Number of samples of each label per client
+        dataset : The name of the dataset to use (e.g. 'fashion-mnist', 'mnist', or 'cifar10')
+        nn_model : The training model type ('linear' or 'convolutional')
 
     Returns:
         client_dataset : Dictionary where each key corresponds to a client index. The samples are stored under the 'x' key and the targets under the 'y' key
@@ -98,7 +103,7 @@ def get_clients_data(num_clients : int, num_samples_by_label : int, dataset : di
 
     import numpy as np
 
-    label_dict = create_label_dict(dataset)
+    label_dict = create_label_dict(dataset, nn_model)
 
     clients_dictionary = {}
     client_dataset = {}
@@ -133,17 +138,22 @@ def rotate_images(client: Client, rotation: int) -> None:
     """
 
     import numpy as np
+    from math import prod
 
     images = client.data['x']
 
-    if rotation >0 :
+    if rotation > 0 :
 
         rotated_images = []
 
         for img in images:
 
+            orig_shape = img.shape
+            img_flatten = img.flatten()
+
             rotated_img = np.rot90(img, k=rotation//90)  # Rotate image by specified angle
+            rotated_img = rotated_img.reshape(*orig_shape)
 
             rotated_images.append(rotated_img)
 
         client.data['x'] = np.array(rotated_images)
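For a square image, np.rot90 with k = rotation // 90 already returns an array of the original shape, so the reshape back to orig_shape acts only as a shape guard; a standalone check of that behaviour:

import numpy as np

img = np.arange(28 * 28, dtype=np.float32).reshape(28, 28)
orig_shape = img.shape
rotated = np.rot90(img, k=180 // 90)     # two quarter-turns = 180 degrees
rotated = rotated.reshape(*orig_shape)   # no-op for square images
assert rotated.shape == orig_shape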
@@ -196,7 +206,6 @@ def data_preparation(client : Client, row_exp : dict) -> None:
 
     return
 
 def get_dataset_heterogeneities(heterogeneity_type: str) -> dict:
     """
@@ -205,7 +214,7 @@ def get_dataset_heterogeneities(heterogeneity_type: str) -> dict:
     Arguments:
         heterogeneity_type : The label of the heterogeneity scenario (labels-distribution-skew, concept-shift-on-labels, quantity-skew)
 
     Returns:
-        A dictionary of the form {<het>: []} where <het> is the applicable heterogeneity type
+        dict_params : A dictionary of the form {<het>: []} where <het> is the applicable heterogeneity type
     """
 
     dict_params = {}
@@ -221,24 +230,22 @@ def get_dataset_heterogeneities(heterogeneity_type: str) -> dict:
             dict_params['skews'] = [0.1, 0.2, 0.6, 1]
 
     return dict_params
 
 def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
-    """
-    Setup function to create and personalize client's data
+    """ Setup function to create and personalize clients' data
 
     Arguments:
         row_exp : The current experiment's global parameters
 
-    Returns:
-        model_server : A nn model used by the server in the FL protocol
-        list_clients : A list of Client objects used as nodes in the FL protocol
+    Returns:
+        model_server, list_clients : The nn model used by the server in the FL protocol and a list of Client objects used as nodes in the FL protocol
     """
 
-    from src.models import SimpleLinear
+    from src.models import SimpleLinear, SimpleConv
     from src.utils_fed import init_server_cluster
     import torch
@@ -246,14 +253,20 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
 
     torch.manual_seed(row_exp['seed'])
 
-    imgs_params = {'mnist': (24,1), 'fashion-mnist': (24,1), 'kmnist': (24,1), 'cifar10': (32,3)}
+    imgs_params = {'mnist': (28,1), 'fashion-mnist': (28,1), 'kmnist': (28,1), 'cifar10': (32,3)}
 
-    model_server = Server(SimpleLinear(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
+    if row_exp['nn_model'] == "linear":
+        model_server = Server(SimpleLinear(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
+    elif row_exp['nn_model'] == "convolutional":
+        model_server = Server(SimpleConv(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
 
     dict_clients = get_clients_data(row_exp['num_clients'],
                                     row_exp['num_samples_by_label'],
                                     row_exp['dataset'],
-                                    row_exp['seed'])
+                                    row_exp['nn_model'])
 
     for i in range(row_exp['num_clients']):
@@ -262,7 +275,8 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
 
     list_clients = add_clients_heterogeneity(list_clients, row_exp)
 
     if row_exp['exp_type'] == "client":
-        init_server_cluster(model_server, list_clients, row_exp, imgs_params['dataset'])
+
+        init_server_cluster(model_server, list_clients, row_exp, imgs_params[row_exp['dataset']])
 
     return model_server, list_clients
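The if/elif branch above could equally be written as a class lookup keyed by the CLI value; this is only an illustrative alternative, not code from the commit, and build_model is a hypothetical helper.

from src.models import SimpleLinear, SimpleConv

MODEL_CLASSES = {"linear": SimpleLinear, "convolutional": SimpleConv}

def build_model(nn_model, in_size, n_channels):
    # Select the model class from the CLI value and instantiate it.
    return MODEL_CLASSES[nn_model](in_size=in_size, n_channels=n_channels)

model = build_model("convolutional", 32, 3)   # (32, 3) matches the cifar10 entry of imgs_params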
@@ -679,27 +693,6 @@ def erode_images(x_train : ndarray, kernel_size : tuple =(3, 3)) -> ndarray:
 
     return eroded_images
 
-def save_results(model_server : Server, row_exp : dict) -> None:
-    """
-    Saves model_server in row_exp['output'] as *.pth object
-
-    Arguments:
-        model_server : The nn.Module to save
-        row_exp : The current experiment's global parameters
-    """
-
-    import torch
-
-    if row_exp['exp_type'] == "client" or "server":
-
-        for cluster_id in range(row_exp['num_clusters']):
-            torch.save(model_server.clusters_models[cluster_id].state_dict(), f"./results/{row_exp['output']}_{row_exp['exp_type']}_model_cluster_{cluster_id}.pth")
-
-    return
-
 def get_uid(str_obj: str) -> str:
     """
     Generates an (almost) unique identifier given a string object.
...
@@ -203,9 +203,9 @@ def summarize_results() -> None:
 
             list_params = path.stem.split('_')
 
-            dict_exp_results = {"exp_type": list_params[0], "dataset": list_params[1], "dataset_type": list_params[2], "number_of_clients": list_params[3],
-                                "samples by_client": list_params[4], "num_clusters": list_params[5], "centralized_epochs": list_params[6],
-                                "federated_rounds": list_params[7], "accuracy": accuracy}
+            dict_exp_results = {"exp_type": list_params[0], "dataset": list_params[1], "nn_model": list_params[2], "dataset_type": list_params[3], "number_of_clients": list_params[4],
+                                "samples by_client": list_params[5], "num_clusters": list_params[6], "centralized_epochs": list_params[7],
+                                "federated_rounds": list_params[8], "accuracy": accuracy}
 
             try:
...
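Because the results file name now carries the nn_model segment, path.stem.split('_') returns one extra field and every index after position 1 shifts by one. For example, with one of the reference names introduced in this commit:

from pathlib import Path

path = Path("server_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
list_params = path.stem.split('_')
# ['server', 'fashion-mnist', 'linear', 'features-distribution-skew', '8', '100', '3', '5', '5', '42']
print(list_params[2])   # 'linear' -> nn_model
print(list_params[8])   # '5'      -> federated_rounds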
@@ -50,7 +50,7 @@ def run_cfl_server_side(model_server : Server, list_clients : list, row_exp : di
 
     return df_results
 
-def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : dict, init_cluster=True) -> pd.DataFrame:
+def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : dict) -> pd.DataFrame:
     """ Driver function for the client-side clustered FL algorithm. The algorithm personalizes training by clusters obtained
     from model weights (k-means).
 
@@ -61,7 +61,6 @@ def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : di
         main_model : Type of Server model needed
         list_clients : A list of Client objects used as nodes in the FL protocol
         row_exp : The current experiment's global parameters
-        init_cluster : A boolean indicating whether to initialize clusters prior to training
     """
 
     from src.utils_fed import set_client_cluster, fedavg
 
@@ -112,7 +111,7 @@ def run_benchmark(main_model : nn.Module, list_clients : list, row_exp : dict) -
 
     torch.manual_seed(row_exp['seed'])
     torch.use_deterministic_algorithms(True)
 
-    curr_model = main_model if row_exp['exp_type'] == 'global-federated' else SimpleLinear()
+    curr_model = main_model if row_exp['exp_type'] == 'global-federated' else main_model.model
 
     match row_exp['exp_type']:
 
@@ -202,7 +201,7 @@ def train_central(main_model, train_loader, row_exp):
     """
 
     criterion = nn.CrossEntropyLoss()
-    optimizer = optim.SGD
+    optimizer = optim.SGD if row_exp['nn_model'] == "linear" else optim.Adam
     optimizer = optimizer(main_model.parameters(), lr=0.01)
 
     main_model.train()
...
id,cluster_id,heterogeneity_class,accuracy
0,1,none,80.66666666666666
1,1,erosion,57.666666666666664
2,1,dilatation,81.66666666666667
3,1,none,83.33333333333334
4,1,erosion,59.333333333333336
5,1,dilatation,81.0
id,cluster_id,heterogeneity_class,accuracy
0,,none,74.27777777777777
1,,erosion,74.27777777777777
2,,dilatation,74.27777777777777
3,,none,74.27777777777777
4,,erosion,74.27777777777777
5,,dilatation,74.27777777777777
id,cluster_id,heterogeneity_class,accuracy
0,,none,64.66666666666666
1,,erosion,39.5
2,,dilatation,79.83333333333333
3,,none,64.66666666666666
4,,erosion,39.5
5,,dilatation,79.83333333333333
id,cluster_id,heterogeneity_class,accuracy
0,2,none,85.66666666666667
1,1,erosion,66.66666666666666
2,0,dilatation,86.66666666666667
3,2,none,88.66666666666667
4,1,erosion,68.66666666666667
5,0,dilatation,89.0
@@ -20,7 +20,7 @@ def utils_extract_params(file_path: Path):
 
     with open(file_path, "r") as fp:
 
-        keys = ['exp_type', 'dataset', 'heterogeneity_type', 'num_clients',
+        keys = ['exp_type', 'dataset', 'nn_model', 'heterogeneity_type', 'num_clients',
                 'num_samples_by_label', 'num_clusters', 'centralized_epochs',
                 'federated_rounds', 'seed']
 
@@ -29,7 +29,7 @@ def utils_extract_params(file_path: Path):
 
     row_exp = dict(
         zip(keys,
-            parameters[:3] + [int(x) for x in parameters[3:]])
+            parameters[:4] + [int(x) for x in parameters[4:]])
     )
 
     return row_exp
@@ -44,7 +44,7 @@ def test_run_cfl_benchmark_oracle():
 
     from src.utils_data import setup_experiment
     from src.utils_training import run_benchmark
 
-    file_path = Path("tests/refs/pers-centralized_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv")
+    file_path = Path("tests/refs/pers-centralized_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
 
     row_exp = utils_extract_params(file_path)
@@ -64,7 +64,7 @@ def test_run_cfl_benchmark_fl():
 
     from src.utils_data import setup_experiment
     from src.utils_training import run_benchmark
 
-    file_path = Path("tests/refs/global-federated_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv")
+    file_path = Path("tests/refs/global-federated_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
 
     row_exp = utils_extract_params(file_path)
@@ -84,7 +84,7 @@ def test_run_cfl_client_side():
 
     from src.utils_data import setup_experiment
     from src.utils_training import run_cfl_client_side
 
-    file_path = Path("tests/refs/client_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv")
+    file_path = Path("tests/refs/client_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
 
     row_exp = utils_extract_params(file_path)
@@ -104,7 +104,7 @@ def test_run_cfl_server_side():
 
     from src.utils_data import setup_experiment
     from src.utils_training import run_cfl_server_side
 
-    file_path = Path("tests/refs/server_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv")
+    file_path = Path("tests/refs/server_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
 
     row_exp = utils_extract_params(file_path)
...