From 410202b8a85c7d6ab71929b2ee090ccbf199beeb Mon Sep 17 00:00:00 2001 From: Omar El Rifai <omar.void@gmail.com> Date: Tue, 3 Sep 2024 15:54:15 +0200 Subject: [PATCH] Allow model selection between linear and convolutional --- driver.py | 5 +- src/models.py | 17 +-- src/utils_data.py | 107 ++++++++---------- src/utils_results.py | 6 +- src/utils_training.py | 7 +- ...tures-distribution-skew_8_100_3_5_5_42.csv | 7 -- ...tures-distribution-skew_8_100_3_5_5_42.csv | 7 -- ...tures-distribution-skew_8_100_3_5_5_42.csv | 7 -- ...tures-distribution-skew_8_100_3_5_5_42.csv | 7 -- tests/test_utils_training.py | 12 +- 10 files changed, 74 insertions(+), 108 deletions(-) delete mode 100644 tests/refs/client_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv delete mode 100644 tests/refs/global-federated_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv delete mode 100644 tests/refs/pers-centralized_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv delete mode 100644 tests/refs/server_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv diff --git a/driver.py b/driver.py index d58195b..e28c422 100644 --- a/driver.py +++ b/driver.py @@ -5,6 +5,7 @@ import click @click.option('--exp_type', help="The experiment type to run") @click.option('--heterogeneity_type', help="The data heterogeneity to test (or dataset)") @click.option('--dataset') +@click.option('--nn_model', help= "The training model to use ('linear (default) or 'convolutional')") @click.option('--num_clients', type=int) @click.option('--num_samples_by_label', type=int) @click.option('--num_clusters', type=int) @@ -14,14 +15,14 @@ import click -def main_driver(exp_type, dataset, heterogeneity_type, num_clients, num_samples_by_label, num_clusters, centralized_epochs, federated_rounds, seed): +def main_driver(exp_type, dataset, nn_model, heterogeneity_type, num_clients, num_samples_by_label, num_clusters, centralized_epochs, federated_rounds, seed): from pathlib import Path import pandas as pd from src.utils_data import setup_experiment, get_uid - row_exp = pd.Series({"exp_type": exp_type, "dataset": dataset, "heterogeneity_type": heterogeneity_type, "num_clients": num_clients, + row_exp = pd.Series({"exp_type": exp_type, "dataset": dataset, "nn_model" : nn_model, "heterogeneity_type": heterogeneity_type, "num_clients": num_clients, "num_samples_by_label": num_samples_by_label, "num_clusters": num_clusters, "centralized_epochs": centralized_epochs, "federated_rounds": federated_rounds, "seed": seed}) diff --git a/src/models.py b/src/models.py index 54d82f1..3acd1e1 100644 --- a/src/models.py +++ b/src/models.py @@ -3,10 +3,10 @@ import torch.nn as nn import torch.nn.functional as F -class SimpleLinear2(nn.Module): +class SimpleLinear(nn.Module): """ Fully connected neural network with a single hidden layer of default size 200 and ReLU activations""" - def __init__(self, h1=200): + def __init__(self, in_size, n_channels): """ Initialization function Arguments: @@ -14,8 +14,9 @@ class SimpleLinear2(nn.Module): Desired size of the hidden layer """ super().__init__() - self.fc1 = nn.Linear(28*28, h1) - self.fc2 = nn.Linear(h1, 10) + self.fc1 = nn.Linear(in_size*in_size,200) + self.fc2 = nn.Linear(200, 10) + self.in_size = in_size def forward(self, x: torch.Tensor): @@ -23,19 +24,19 @@ class SimpleLinear2(nn.Module): Arguments: x : torch.Tensor - input image of size 28 x 28 + input image of size in_size x in_size Returns: log_softmax probabilities of the output layer """ - x = x.view(-1, 28 * 28) + x = x.view(-1, self.in_size * self.in_size) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1) -class SimpleLinear(nn.Module): +class SimpleConv(nn.Module): """ Convolutional neural network with 3 convolutional layers and one fully connected layer """ @@ -43,7 +44,7 @@ class SimpleLinear(nn.Module): def __init__(self, in_size, n_channels): """ Initialization function """ - super(SimpleLinear, self).__init__() + super(SimpleConv, self).__init__() self.conv1 = nn.Conv2d(n_channels, 16, 3, padding=1) self.conv2 = nn.Conv2d(16, 32, 3, padding=1) diff --git a/src/utils_data.py b/src/utils_data.py index 0935676..8cee10f 100644 --- a/src/utils_data.py +++ b/src/utils_data.py @@ -28,38 +28,39 @@ def shuffle_list(list_samples : int, seed : int) -> list: return shuffled_list - -def create_label_dict(dataset : dict) -> dict: - - """ Create a dictionary of dataset samples +def create_label_dict(dataset : str, nn_model : str) -> dict: + + """Create a dictionary of dataset samples Arguments: - dataset: The name of the dataset to use ('fashion-mnist', 'mnist', or 'kmnist') - + dataset : The name of the dataset to use (e.g 'fashion-mnist', 'mnist', or 'cifar10') + nn_model : the training model type ('linear' or 'convolutional') + Returns: - A dictionary of data of the form {'x': [], 'y': []} + label_dict : A dictionary of data of the form {'x': [], 'y': []} Raises: - Error if the dataset name is unrecognized - + Error : if the dataset name is unrecognized """ + import sys import numpy as np - import torchvision - from tensorflow.keras.datasets import mnist, fashion_mnist from extra_keras_datasets import kmnist - #import torchvision - if dataset == "fashion-mnist": fashion_mnist = torchvision.datasets.MNIST("datasets", download=True) (x_train, y_train) = fashion_mnist.data, fashion_mnist.targets + if nn_model == "convolutional": + x_train = x_train.unsqueeze(1) + elif dataset == 'mnist': mnist = torchvision.datasets.MNIST("datasets", download=True) (x_train, y_train) = mnist.data, mnist.targets - x_train = x_train.unsqueeze(1) + + if nn_model == "convolutional": + x_train = x_train.unsqueeze(1) elif dataset == "cifar10": cifar10 = torchvision.datasets.CIFAR10("datasets", download=True) @@ -68,6 +69,9 @@ def create_label_dict(dataset : dict) -> dict: elif dataset == 'kmnist': (x_train, y_train), _ = kmnist.load_data() + + if nn_model == "convolutional": + x_train = x_train.unsqueeze(1) else: sys.exit("Unrecognized dataset. Please make sure you are using one of the following ['mnist', fashion-mnist', 'kmnist']") @@ -83,14 +87,15 @@ def create_label_dict(dataset : dict) -> dict: return label_dict -def get_clients_data(num_clients : int, num_samples_by_label : int, dataset : dict, seed : int) -> dict: +def get_clients_data(num_clients : int, num_samples_by_label : int, dataset : str, nn_model : str) -> dict: """Distribute a dataset evenly accross num_clients clients. Works with datasets with 10 labels Arguments: num_clients : Number of clients of interest - num_samples_by_label : Number of samples of each labels by client + dataset: The name of the dataset to use (e.g 'fashion-mnist', 'mnist', or 'cifar10') + nn_model : the training model type ('linear' or 'convolutional') Returns: client_dataset : Dictionnary where each key correspond to a client index. The samples will be contained in the 'x' key and the target in 'y' key @@ -98,7 +103,7 @@ def get_clients_data(num_clients : int, num_samples_by_label : int, dataset : di import numpy as np - label_dict = create_label_dict(dataset) + label_dict = create_label_dict(dataset, nn_model) clients_dictionary = {} client_dataset = {} @@ -133,17 +138,22 @@ def rotate_images(client: Client, rotation: int) -> None: """ import numpy as np + from math import prod images = client.data['x'] - - if rotation >0 : - + + if rotation > 0 : + rotated_images = [] for img in images: + orig_shape = img.shape + img_flatten = img.flatten() + rotated_img = np.rot90(img, k=rotation//90) # Rotate image by specified angle - + rotated_img = rotated_img.reshape(*orig_shape) + rotated_images.append(rotated_img) client.data['x'] = np.array(rotated_images) @@ -196,7 +206,6 @@ def data_preparation(client : Client, row_exp : dict) -> None: return - def get_dataset_heterogeneities(heterogeneity_type: str) -> dict: """ @@ -205,7 +214,7 @@ def get_dataset_heterogeneities(heterogeneity_type: str) -> dict: Arguments: heterogeneity_type : The label of the heterogeneity scenario (labels-distribution-skew, concept-shift-on-labels, quantity-skew) Returns: - A dictionary of the form {<het>: []} where <het> is the applicable heterogeneity type + dict_params: A dictionary of the form {<het>: []} where <het> is the applicable heterogeneity type """ dict_params = {} @@ -221,24 +230,22 @@ def get_dataset_heterogeneities(heterogeneity_type: str) -> dict: dict_params['skews'] = [0.1,0.2,0.6,1] return dict_params - + def setup_experiment(row_exp: dict) -> Tuple[Server, list]: - """ - Setup function to create and personalize client's data + """ Setup function to create and personalize client's data Arguments: row_exp : The current experiment's global parameters - - Returns: - - model_server : A nn model used the server in the FL protocol - list_clients : A list of Client Objects used as nodes in the FL protocol + + Returns: + model_server, list_clients: a nn model used the server in the FL protocol, a list of Client Objects used as nodes in the FL protocol + """ - from src.models import SimpleLinear + from src.models import SimpleLinear, SimpleConv from src.utils_fed import init_server_cluster import torch @@ -246,14 +253,20 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]: torch.manual_seed(row_exp['seed']) - imgs_params = {'mnist': (24,1) , 'fashion-mnist': (24,1), 'kmnist': (24,1), 'cifar10': (32,3)} + imgs_params = {'mnist': (28,1) , 'fashion-mnist': (28,1), 'kmnist': (28,1), 'cifar10': (32,3)} - model_server = Server(SimpleLinear(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1])) + if row_exp['nn_model'] == "linear": + + model_server = Server(SimpleLinear(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1])) + + elif row_exp['nn_model'] == "convolutional": + + model_server = Server(SimpleConv(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1])) dict_clients = get_clients_data(row_exp['num_clients'], row_exp['num_samples_by_label'], row_exp['dataset'], - row_exp['seed']) + row_exp['nn_model']) for i in range(row_exp['num_clients']): @@ -262,7 +275,8 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]: list_clients = add_clients_heterogeneity(list_clients, row_exp) if row_exp['exp_type'] == "client": - init_server_cluster(model_server, list_clients, row_exp, imgs_params['dataset']) + + init_server_cluster(model_server, list_clients, row_exp, imgs_params[row_exp['dataset']]) return model_server, list_clients @@ -679,27 +693,6 @@ def erode_images(x_train : ndarray, kernel_size : tuple =(3, 3)) -> ndarray: return eroded_images - -def save_results(model_server : Server, row_exp : dict ) -> None: - """ - Saves model_server in row_exp['output'] as *.pth object - - Arguments: - model_server : The nn.Module to save - row_exp : The current experiment's global parameters - """ - - import torch - - if row_exp['exp_type'] == "client" or "server": - - for cluster_id in range(row_exp['num_clusters']): - - torch.save(model_server.clusters_models[cluster_id].state_dict(), f"./results/{row_exp['output']}_{row_exp['exp_type']}_model_cluster_{cluster_id}.pth") - - return - - def get_uid(str_obj: str) -> str: """ Generates an (almost) unique Identifier given a string object. diff --git a/src/utils_results.py b/src/utils_results.py index 01e6855..5b08515 100644 --- a/src/utils_results.py +++ b/src/utils_results.py @@ -203,9 +203,9 @@ def summarize_results() -> None: list_params = path.stem.split('_') - dict_exp_results = {"exp_type" : list_params[0], "dataset": list_params[1], "dataset_type": list_params[2], "number_of_clients": list_params[3], - "samples by_client": list_params[4], "num_clusters": list_params[5], "centralized_epochs": list_params[6], - "federated_rounds": list_params[7],"accuracy": accuracy} + dict_exp_results = {"exp_type" : list_params[0], "dataset": list_params[1], "nn_model" : list_params[2], "dataset_type": list_params[3], "number_of_clients": list_params[4], + "samples by_client": list_params[5], "num_clusters": list_params[6], "centralized_epochs": list_params[7], + "federated_rounds": list_params[8],"accuracy": accuracy} try: diff --git a/src/utils_training.py b/src/utils_training.py index 91e971d..1abd81d 100644 --- a/src/utils_training.py +++ b/src/utils_training.py @@ -50,7 +50,7 @@ def run_cfl_server_side(model_server : Server, list_clients : list, row_exp : di return df_results -def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : dict, init_cluster=True) -> pd.DataFrame: +def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : dict) -> pd.DataFrame: """ Driver function for client-side cluster FL algorithm. The algorithm personalize training by clusters obtained from model weights (k-means). @@ -61,7 +61,6 @@ def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : di main_model : Type of Server model needed list_clients : A list of Client Objects used as nodes in the FL protocol row_exp : The current experiment's global parameters - init_cluster : A boolean indicating whether to initialize cluster prior to training """ from src.utils_fed import set_client_cluster, fedavg @@ -112,7 +111,7 @@ def run_benchmark(main_model : nn.Module, list_clients : list, row_exp : dict) - torch.manual_seed(row_exp['seed']) torch.use_deterministic_algorithms(True) - curr_model = main_model if row_exp['exp_type'] == 'global-federated' else SimpleLinear() + curr_model = main_model if row_exp['exp_type'] == 'global-federated' else main_model.model match row_exp['exp_type']: @@ -202,7 +201,7 @@ def train_central(main_model, train_loader, row_exp): """ criterion = nn.CrossEntropyLoss() - optimizer=optim.SGD + optimizer=optim.SGD if row_exp['nn_model'] == "linear" else optim.Adam optimizer = optimizer(main_model.parameters(), lr=0.01) main_model.train() diff --git a/tests/refs/client_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv b/tests/refs/client_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv deleted file mode 100644 index d212f0c..0000000 --- a/tests/refs/client_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv +++ /dev/null @@ -1,7 +0,0 @@ -id,cluster_id,heterogeneity_class,accuracy -0,1,none,80.66666666666666 -1,1,erosion,57.666666666666664 -2,1,dilatation,81.66666666666667 -3,1,none,83.33333333333334 -4,1,erosion,59.333333333333336 -5,1,dilatation,81.0 diff --git a/tests/refs/global-federated_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv b/tests/refs/global-federated_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv deleted file mode 100644 index 15b9856..0000000 --- a/tests/refs/global-federated_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv +++ /dev/null @@ -1,7 +0,0 @@ -id,cluster_id,heterogeneity_class,accuracy -0,,none,74.27777777777777 -1,,erosion,74.27777777777777 -2,,dilatation,74.27777777777777 -3,,none,74.27777777777777 -4,,erosion,74.27777777777777 -5,,dilatation,74.27777777777777 diff --git a/tests/refs/pers-centralized_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv b/tests/refs/pers-centralized_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv deleted file mode 100644 index 29162ca..0000000 --- a/tests/refs/pers-centralized_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv +++ /dev/null @@ -1,7 +0,0 @@ -id,cluster_id,heterogeneity_class,accuracy -0,,none,64.66666666666666 -1,,erosion,39.5 -2,,dilatation,79.83333333333333 -3,,none,64.66666666666666 -4,,erosion,39.5 -5,,dilatation,79.83333333333333 diff --git a/tests/refs/server_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv b/tests/refs/server_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv deleted file mode 100644 index 82214a8..0000000 --- a/tests/refs/server_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv +++ /dev/null @@ -1,7 +0,0 @@ -id,cluster_id,heterogeneity_class,accuracy -0,2,none,85.66666666666667 -1,1,erosion,66.66666666666666 -2,0,dilatation,86.66666666666667 -3,2,none,88.66666666666667 -4,1,erosion,68.66666666666667 -5,0,dilatation,89.0 diff --git a/tests/test_utils_training.py b/tests/test_utils_training.py index f850bbd..5082727 100644 --- a/tests/test_utils_training.py +++ b/tests/test_utils_training.py @@ -20,7 +20,7 @@ def utils_extract_params(file_path: Path): with open (file_path, "r") as fp: - keys = ['exp_type', 'dataset' , 'heterogeneity_type' , 'num_clients', + keys = ['exp_type', 'dataset', 'nn_model', 'heterogeneity_type' , 'num_clients', 'num_samples_by_label' , 'num_clusters', 'centralized_epochs', 'federated_rounds', 'seed'] @@ -29,7 +29,7 @@ def utils_extract_params(file_path: Path): row_exp = dict( zip(keys, - parameters[:3] + [int(x) for x in parameters[3:]]) + parameters[:4] + [int(x) for x in parameters[4:]]) ) return row_exp @@ -44,7 +44,7 @@ def test_run_cfl_benchmark_oracle(): from src.utils_data import setup_experiment from src.utils_training import run_benchmark - file_path = Path("tests/refs/pers-centralized_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv") + file_path = Path("tests/refs/pers-centralized_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv") row_exp = utils_extract_params(file_path) @@ -64,7 +64,7 @@ def test_run_cfl_benchmark_fl(): from src.utils_data import setup_experiment from src.utils_training import run_benchmark - file_path = Path("tests/refs/global-federated_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv") + file_path = Path("tests/refs/global-federated_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv") row_exp = utils_extract_params(file_path) @@ -84,7 +84,7 @@ def test_run_cfl_client_side(): from src.utils_data import setup_experiment from src.utils_training import run_cfl_client_side - file_path = Path("tests/refs/client_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv") + file_path = Path("tests/refs/client_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv") row_exp = utils_extract_params(file_path) @@ -104,7 +104,7 @@ def test_run_cfl_server_side(): from src.utils_data import setup_experiment from src.utils_training import run_cfl_server_side - file_path = Path("tests/refs/server_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv") + file_path = Path("tests/refs/server_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv") row_exp = utils_extract_params(file_path) -- GitLab