From 410202b8a85c7d6ab71929b2ee090ccbf199beeb Mon Sep 17 00:00:00 2001
From: Omar El Rifai <omar.void@gmail.com>
Date: Tue, 3 Sep 2024 15:54:15 +0200
Subject: [PATCH] Allow model selection between linear and convolutional

---
 driver.py                                     |   5 +-
 src/models.py                                 |  17 +--
 src/utils_data.py                             | 107 ++++++++----------
 src/utils_results.py                          |   6 +-
 src/utils_training.py                         |   7 +-
 ...tures-distribution-skew_8_100_3_5_5_42.csv |   7 --
 ...tures-distribution-skew_8_100_3_5_5_42.csv |   7 --
 ...tures-distribution-skew_8_100_3_5_5_42.csv |   7 --
 ...tures-distribution-skew_8_100_3_5_5_42.csv |   7 --
 tests/test_utils_training.py                  |  12 +-
 10 files changed, 74 insertions(+), 108 deletions(-)
 delete mode 100644 tests/refs/client_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv
 delete mode 100644 tests/refs/global-federated_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv
 delete mode 100644 tests/refs/pers-centralized_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv
 delete mode 100644 tests/refs/server_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv

diff --git a/driver.py b/driver.py
index d58195b..e28c422 100644
--- a/driver.py
+++ b/driver.py
@@ -5,6 +5,7 @@ import click
 @click.option('--exp_type', help="The experiment type to run")
 @click.option('--heterogeneity_type', help="The data heterogeneity to test (or dataset)")
 @click.option('--dataset')
+@click.option('--nn_model', help= "The training model to use ('linear (default) or 'convolutional')")
 @click.option('--num_clients', type=int)
 @click.option('--num_samples_by_label', type=int)
 @click.option('--num_clusters', type=int)
@@ -14,14 +15,14 @@ import click
 
 
 
-def main_driver(exp_type, dataset, heterogeneity_type, num_clients, num_samples_by_label, num_clusters, centralized_epochs, federated_rounds, seed):
+def main_driver(exp_type, dataset, nn_model, heterogeneity_type, num_clients, num_samples_by_label, num_clusters, centralized_epochs, federated_rounds, seed):
 
     from pathlib import Path
     import pandas as pd
 
     from src.utils_data import setup_experiment, get_uid 
 
-    row_exp = pd.Series({"exp_type": exp_type, "dataset": dataset, "heterogeneity_type": heterogeneity_type, "num_clients": num_clients,
+    row_exp = pd.Series({"exp_type": exp_type, "dataset": dataset, "nn_model" : nn_model, "heterogeneity_type": heterogeneity_type, "num_clients": num_clients,
                "num_samples_by_label": num_samples_by_label, "num_clusters": num_clusters, "centralized_epochs": centralized_epochs,
                "federated_rounds": federated_rounds, "seed": seed})
     
diff --git a/src/models.py b/src/models.py
index 54d82f1..3acd1e1 100644
--- a/src/models.py
+++ b/src/models.py
@@ -3,10 +3,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-class SimpleLinear2(nn.Module):
+class SimpleLinear(nn.Module):
     """ Fully connected neural network with a single hidden layer of default size 200 and ReLU activations"""
     
-    def __init__(self, h1=200):
+    def __init__(self, in_size, n_channels):
         
         """ Initialization function
         Arguments:
@@ -14,8 +14,9 @@ class SimpleLinear2(nn.Module):
                 Desired size of the hidden layer 
         """
         super().__init__()
-        self.fc1 = nn.Linear(28*28, h1)
-        self.fc2 = nn.Linear(h1, 10)
+        self.fc1 = nn.Linear(in_size*in_size,200)
+        self.fc2 = nn.Linear(200, 10)
+        self.in_size = in_size
 
     def forward(self, x: torch.Tensor):
         
@@ -23,19 +24,19 @@ class SimpleLinear2(nn.Module):
         
         Arguments:
             x : torch.Tensor
-                input image of size 28 x 28
+                input image of size in_size x in_size
 
         Returns: 
             log_softmax probabilities of the output layer
         """
         
-        x = x.view(-1, 28 * 28)
+        x = x.view(-1, self.in_size * self.in_size)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
     
 
-class SimpleLinear(nn.Module):
+class SimpleConv(nn.Module):
 
     """ Convolutional neural network with 3 convolutional layers and one fully connected layer
     """
@@ -43,7 +44,7 @@ class SimpleLinear(nn.Module):
     def __init__(self,  in_size, n_channels):
         """ Initialization function
         """
-        super(SimpleLinear, self).__init__()
+        super(SimpleConv, self).__init__()
                 
         self.conv1 = nn.Conv2d(n_channels, 16, 3, padding=1)
         self.conv2 = nn.Conv2d(16, 32, 3,  padding=1)
diff --git a/src/utils_data.py b/src/utils_data.py
index 0935676..8cee10f 100644
--- a/src/utils_data.py
+++ b/src/utils_data.py
@@ -28,38 +28,39 @@ def shuffle_list(list_samples : int, seed : int) -> list:
     return shuffled_list
 
 
-
-def create_label_dict(dataset : dict) -> dict:
-   
-    """ Create a dictionary of dataset samples 
+def create_label_dict(dataset : str, nn_model : str) -> dict:
+    
+    """Create a dictionary of dataset samples
 
     Arguments:
-        dataset: The name of the dataset to use ('fashion-mnist', 'mnist', or 'kmnist')
-   
+        dataset : The name of the dataset to use (e.g 'fashion-mnist', 'mnist', or 'cifar10')
+        nn_model : the training model type ('linear' or 'convolutional') 
+
     Returns:
-        A dictionary of data of the form {'x': [], 'y': []}
+        label_dict : A dictionary of data of the form {'x': [], 'y': []}
 
     Raises:
-        Error if the dataset name is unrecognized
-    
+        Error : if the dataset name is unrecognized
     """
+    
     import sys
     import numpy as np
-    
     import torchvision
-    from tensorflow.keras.datasets import mnist, fashion_mnist
     from extra_keras_datasets import kmnist
     
-    #import torchvision
-
     if dataset == "fashion-mnist":
         fashion_mnist = torchvision.datasets.MNIST("datasets", download=True)
         (x_train, y_train) = fashion_mnist.data, fashion_mnist.targets
     
+        if nn_model == "convolutional":
+            x_train = x_train.unsqueeze(1)
+
     elif dataset == 'mnist':
         mnist = torchvision.datasets.MNIST("datasets", download=True)
         (x_train, y_train) = mnist.data, mnist.targets
-        x_train = x_train.unsqueeze(1)
+        
+        if nn_model == "convolutional":
+            x_train = x_train.unsqueeze(1)
 
     elif dataset == "cifar10":
         cifar10 = torchvision.datasets.CIFAR10("datasets", download=True)
@@ -68,6 +69,9 @@ def create_label_dict(dataset : dict) -> dict:
 
     elif dataset == 'kmnist':
         (x_train, y_train), _ = kmnist.load_data()
+
+        if nn_model == "convolutional":
+            x_train = x_train.unsqueeze(1)
     
     else:
         sys.exit("Unrecognized dataset. Please make sure you are using one of the following ['mnist', fashion-mnist', 'kmnist']")    
@@ -83,14 +87,15 @@ def create_label_dict(dataset : dict) -> dict:
     return label_dict
 
 
-def get_clients_data(num_clients : int, num_samples_by_label : int, dataset : dict, seed : int) -> dict:
+def get_clients_data(num_clients : int, num_samples_by_label : int, dataset : str, nn_model : str) -> dict:
     
     """Distribute a dataset evenly accross num_clients clients. Works with datasets with 10 labels
     
     Arguments:
         num_clients : Number of clients of interest
-            
         num_samples_by_label : Number of samples of each labels by client
+        dataset: The name of the dataset to use (e.g 'fashion-mnist', 'mnist', or 'cifar10')
+        nn_model : the training model type ('linear' or 'convolutional')
 
     Returns:
         client_dataset :  Dictionnary where each key correspond to a client index. The samples will be contained in the 'x' key and the target in 'y' key
@@ -98,7 +103,7 @@ def get_clients_data(num_clients : int, num_samples_by_label : int, dataset : di
     
     import numpy as np 
 
-    label_dict = create_label_dict(dataset)
+    label_dict = create_label_dict(dataset, nn_model)
 
     clients_dictionary = {}
     client_dataset = {}
@@ -133,17 +138,22 @@ def rotate_images(client: Client, rotation: int) -> None:
     """
     
     import numpy as np
+    from math import prod
 
     images = client.data['x']
-    
-    if rotation >0 :
-    
+
+    if rotation > 0 :
+
         rotated_images = []
     
         for img in images:
     
+            orig_shape = img.shape             
+            img_flatten = img.flatten()
+
             rotated_img = np.rot90(img, k=rotation//90)  # Rotate image by specified angle
-    
+            rotated_img = rotated_img.reshape(*orig_shape)
+
             rotated_images.append(rotated_img)   
     
         client.data['x'] = np.array(rotated_images)
@@ -196,7 +206,6 @@ def data_preparation(client : Client, row_exp : dict) -> None:
     return 
 
 
-
 def get_dataset_heterogeneities(heterogeneity_type: str) -> dict:
 
     """
@@ -205,7 +214,7 @@ def get_dataset_heterogeneities(heterogeneity_type: str) -> dict:
     Arguments:
         heterogeneity_type : The label of the heterogeneity scenario (labels-distribution-skew, concept-shift-on-labels, quantity-skew)
     Returns:
-        A dictionary of the form {<het>: []} where <het> is the applicable heterogeneity type 
+        dict_params: A dictionary of the form {<het>: []} where <het> is the applicable heterogeneity type 
     """
     dict_params = {}
 
@@ -221,24 +230,22 @@ def get_dataset_heterogeneities(heterogeneity_type: str) -> dict:
         dict_params['skews'] = [0.1,0.2,0.6,1]
 
     return dict_params
-
+    
 
 def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
 
-    """
-    Setup function to create and personalize client's data 
+    """ Setup function to create and personalize client's data 
 
     Arguments:
         row_exp : The current experiment's global parameters
-    
-    Returns:
-        
-        model_server : A nn model used the server in the FL protocol
 
-        list_clients : A list of Client Objects used as nodes in the FL protocol
+
+    Returns: 
+        model_server, list_clients: a nn model used the server in the FL protocol, a list of Client Objects used as nodes in the FL protocol
+
     """
 
-    from src.models import SimpleLinear
+    from src.models import SimpleLinear, SimpleConv
     from src.utils_fed import init_server_cluster
     import torch
     
@@ -246,14 +253,20 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
     
     torch.manual_seed(row_exp['seed'])
 
-    imgs_params = {'mnist': (24,1) , 'fashion-mnist': (24,1), 'kmnist': (24,1), 'cifar10': (32,3)}
+    imgs_params = {'mnist': (28,1) , 'fashion-mnist': (28,1), 'kmnist': (28,1), 'cifar10': (32,3)}
 
-    model_server = Server(SimpleLinear(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
+    if row_exp['nn_model'] == "linear":
+        
+        model_server = Server(SimpleLinear(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1])) 
+    
+    elif row_exp['nn_model'] == "convolutional": 
+        
+        model_server = Server(SimpleConv(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
 
     dict_clients = get_clients_data(row_exp['num_clients'],
                                     row_exp['num_samples_by_label'],
                                     row_exp['dataset'],
-                                    row_exp['seed'])    
+                                    row_exp['nn_model'])    
     
     for i in range(row_exp['num_clients']):
 
@@ -262,7 +275,8 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
     list_clients = add_clients_heterogeneity(list_clients, row_exp)
     
     if row_exp['exp_type'] == "client":
-        init_server_cluster(model_server, list_clients, row_exp, imgs_params['dataset'])
+
+        init_server_cluster(model_server, list_clients, row_exp, imgs_params[row_exp['dataset']])
 
     return model_server, list_clients
 
@@ -679,27 +693,6 @@ def erode_images(x_train : ndarray, kernel_size : tuple =(3, 3)) -> ndarray:
     return eroded_images
 
 
-
-def save_results(model_server : Server, row_exp : dict ) -> None:
-    """
-    Saves model_server in row_exp['output'] as *.pth object
-
-    Arguments:
-        model_server : The nn.Module to save
-        row_exp :  The current experiment's global parameters
-    """
-    
-    import torch
-
-    if row_exp['exp_type'] == "client" or "server":
-    
-        for cluster_id in range(row_exp['num_clusters']): 
-    
-            torch.save(model_server.clusters_models[cluster_id].state_dict(), f"./results/{row_exp['output']}_{row_exp['exp_type']}_model_cluster_{cluster_id}.pth")
-
-    return 
-
-
 def get_uid(str_obj: str) -> str:
     """
     Generates an (almost) unique Identifier given a string object.
diff --git a/src/utils_results.py b/src/utils_results.py
index 01e6855..5b08515 100644
--- a/src/utils_results.py
+++ b/src/utils_results.py
@@ -203,9 +203,9 @@ def summarize_results() -> None:
 
             list_params = path.stem.split('_')      
 
-            dict_exp_results = {"exp_type" : list_params[0], "dataset": list_params[1], "dataset_type": list_params[2], "number_of_clients": list_params[3],
-                                    "samples by_client": list_params[4], "num_clusters": list_params[5], "centralized_epochs": list_params[6],
-                                    "federated_rounds": list_params[7],"accuracy": accuracy}
+            dict_exp_results = {"exp_type" : list_params[0], "dataset": list_params[1], "nn_model" : list_params[2], "dataset_type": list_params[3], "number_of_clients": list_params[4],
+                                    "samples by_client": list_params[5], "num_clusters": list_params[6], "centralized_epochs": list_params[7],
+                                    "federated_rounds": list_params[8],"accuracy": accuracy}
 
             try:
                 
diff --git a/src/utils_training.py b/src/utils_training.py
index 91e971d..1abd81d 100644
--- a/src/utils_training.py
+++ b/src/utils_training.py
@@ -50,7 +50,7 @@ def run_cfl_server_side(model_server : Server, list_clients : list, row_exp : di
     return df_results 
 
 
-def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : dict, init_cluster=True) -> pd.DataFrame:
+def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : dict) -> pd.DataFrame:
 
     """ Driver function for client-side cluster FL algorithm. The algorithm personalize training by clusters obtained
     from model weights (k-means).
@@ -61,7 +61,6 @@ def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : di
         main_model : Type of Server model needed    
         list_clients : A list of Client Objects used as nodes in the FL protocol  
         row_exp : The current experiment's global parameters
-        init_cluster : A boolean indicating whether to initialize cluster prior to training
     """
 
     from src.utils_fed import  set_client_cluster, fedavg
@@ -112,7 +111,7 @@ def run_benchmark(main_model : nn.Module, list_clients : list, row_exp : dict) -
     torch.manual_seed(row_exp['seed'])
     torch.use_deterministic_algorithms(True)
 
-    curr_model = main_model if row_exp['exp_type'] == 'global-federated' else SimpleLinear()
+    curr_model = main_model if row_exp['exp_type'] == 'global-federated' else main_model.model
 
     match row_exp['exp_type']:
     
@@ -202,7 +201,7 @@ def train_central(main_model, train_loader, row_exp):
     """
     criterion = nn.CrossEntropyLoss()
     
-    optimizer=optim.SGD
+    optimizer=optim.SGD if row_exp['nn_model'] == "linear" else optim.Adam
     optimizer = optimizer(main_model.parameters(), lr=0.01) 
    
     main_model.train()
diff --git a/tests/refs/client_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv b/tests/refs/client_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv
deleted file mode 100644
index d212f0c..0000000
--- a/tests/refs/client_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv
+++ /dev/null
@@ -1,7 +0,0 @@
-id,cluster_id,heterogeneity_class,accuracy
-0,1,none,80.66666666666666
-1,1,erosion,57.666666666666664
-2,1,dilatation,81.66666666666667
-3,1,none,83.33333333333334
-4,1,erosion,59.333333333333336
-5,1,dilatation,81.0
diff --git a/tests/refs/global-federated_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv b/tests/refs/global-federated_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv
deleted file mode 100644
index 15b9856..0000000
--- a/tests/refs/global-federated_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv
+++ /dev/null
@@ -1,7 +0,0 @@
-id,cluster_id,heterogeneity_class,accuracy
-0,,none,74.27777777777777
-1,,erosion,74.27777777777777
-2,,dilatation,74.27777777777777
-3,,none,74.27777777777777
-4,,erosion,74.27777777777777
-5,,dilatation,74.27777777777777
diff --git a/tests/refs/pers-centralized_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv b/tests/refs/pers-centralized_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv
deleted file mode 100644
index 29162ca..0000000
--- a/tests/refs/pers-centralized_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv
+++ /dev/null
@@ -1,7 +0,0 @@
-id,cluster_id,heterogeneity_class,accuracy
-0,,none,64.66666666666666
-1,,erosion,39.5
-2,,dilatation,79.83333333333333
-3,,none,64.66666666666666
-4,,erosion,39.5
-5,,dilatation,79.83333333333333
diff --git a/tests/refs/server_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv b/tests/refs/server_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv
deleted file mode 100644
index 82214a8..0000000
--- a/tests/refs/server_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv
+++ /dev/null
@@ -1,7 +0,0 @@
-id,cluster_id,heterogeneity_class,accuracy
-0,2,none,85.66666666666667
-1,1,erosion,66.66666666666666
-2,0,dilatation,86.66666666666667
-3,2,none,88.66666666666667
-4,1,erosion,68.66666666666667
-5,0,dilatation,89.0
diff --git a/tests/test_utils_training.py b/tests/test_utils_training.py
index f850bbd..5082727 100644
--- a/tests/test_utils_training.py
+++ b/tests/test_utils_training.py
@@ -20,7 +20,7 @@ def utils_extract_params(file_path: Path):
 
     with open (file_path, "r") as fp:
         
-        keys = ['exp_type', 'dataset' , 'heterogeneity_type' , 'num_clients',
+        keys = ['exp_type', 'dataset', 'nn_model', 'heterogeneity_type' , 'num_clients',
                 'num_samples_by_label' , 'num_clusters', 'centralized_epochs',
                 'federated_rounds', 'seed']
         
@@ -29,7 +29,7 @@ def utils_extract_params(file_path: Path):
 
         row_exp = dict(
             zip(keys,
-                parameters[:3] + [int(x) for x in  parameters[3:]])
+                parameters[:4] + [int(x) for x in  parameters[4:]])
             )
     
     return row_exp
@@ -44,7 +44,7 @@ def test_run_cfl_benchmark_oracle():
     from src.utils_data import setup_experiment    
     from src.utils_training import run_benchmark
 
-    file_path = Path("tests/refs/pers-centralized_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv")
+    file_path = Path("tests/refs/pers-centralized_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
 
     row_exp = utils_extract_params(file_path) 
    
@@ -64,7 +64,7 @@ def test_run_cfl_benchmark_fl():
     from src.utils_data import setup_experiment    
     from src.utils_training import run_benchmark
 
-    file_path = Path("tests/refs/global-federated_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv")
+    file_path = Path("tests/refs/global-federated_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
 
     row_exp = utils_extract_params(file_path) 
    
@@ -84,7 +84,7 @@ def test_run_cfl_client_side():
     from src.utils_data import setup_experiment    
     from src.utils_training import run_cfl_client_side
 
-    file_path = Path("tests/refs/client_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv")
+    file_path = Path("tests/refs/client_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
 
     row_exp = utils_extract_params(file_path) 
    
@@ -104,7 +104,7 @@ def test_run_cfl_server_side():
     from src.utils_data import setup_experiment    
     from src.utils_training import run_cfl_server_side
 
-    file_path = Path("tests/refs/server_fashion-mnist_features-distribution-skew_8_100_3_5_5_42.csv")
+    file_path = Path("tests/refs/server_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
 
     row_exp = utils_extract_params(file_path) 
    
-- 
GitLab