diff --git a/src/fedclass.py b/src/fedclass.py
index d859bdee3793b46ea288ca35031e264e973ec053..4a1a3476ccbb86309ea64fdb796dfe0b3ff07c55 100644
--- a/src/fedclass.py
+++ b/src/fedclass.py
@@ -12,7 +12,7 @@ class Client:
         
         """Initialize the Client object
 
-        Args:
+        Arguments:
             id : int
                 unique client identifier
             data : dict
@@ -70,7 +70,7 @@ class Server:
     def __init__(self,model,num_clusters: int=None):
         """Initialize a Server object with an empty dictionary of cluster_models
 
-        Args:
+        Arguments:
         model: nn.Module
             The nn learing model the server is associated with
         num_clusters: int
diff --git a/src/metrics.py b/src/metrics.py
index eadbd06214a1078f3ae001f7a5324cc0b9682b0f..651b3240611c4fcec1cd1125213484d916a3ead3 100644
--- a/src/metrics.py
+++ b/src/metrics.py
@@ -3,7 +3,7 @@ def calc_global_metrics(labels_true: list, labels_pred: list) -> dict:
 
     """ Calculate global metrics based on model weights
 
-    Args:
+    Arguments:
         labels_true : list
             list of ground truth labels
         labels_pred : list
diff --git a/src/models.py b/src/models.py
index 54fe9aa3ba461a4369770f1148aea6831a4614d6..54d82f132f36dd65d40ae40b2975ee2ca867ed0b 100644
--- a/src/models.py
+++ b/src/models.py
@@ -3,13 +3,13 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-class SimpleLinear(nn.Module):
+class SimpleLinear2(nn.Module):
     """ Fully connected neural network with a single hidden layer of default size 200 and ReLU activations"""
     
     def __init__(self, h1=200):
         
         """ Initialization function
-        Args:
+        Arguments:
             h1: int
                 Desired size of the hidden layer 
         """
@@ -21,11 +21,12 @@ class SimpleLinear(nn.Module):
         
         """ Forward pass function through the network
         
-        Args:
+        Arguments:
             x : torch.Tensor
                 input image of size 28 x 28
 
-        Returns: log_softmax probabilities of the output layer
+        Returns: 
+            log_softmax probabilities of the output layer
         """
         
         x = x.view(-1, 28 * 28)
@@ -34,33 +35,33 @@ class SimpleLinear(nn.Module):
         return F.log_softmax(x, dim=1)
     
 
-class SimpleConv(nn.Module):
+class SimpleLinear(nn.Module):
 
     """ Convolutional neural network with 3 convolutional layers and one fully connected layer
     """
 
-    def __init__(self):
+    def __init__(self,  in_size, n_channels):
         """ Initialization function
         """
-        super(SimpleConv, self).__init__()
-        # convolutional layer
-        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
-        self.conv2 = nn.Conv2d(16, 32, 3, padding = 1)
-        self.conv3 = nn.Conv2d(32, 16, 3, padding = 1)
-        # max pooling layer
-        self.pool = nn.MaxPool2d(2, 2)
+        super(SimpleLinear, self).__init__()
+                
+        self.conv1 = nn.Conv2d(n_channels, 16, 3, padding=1)
+        self.conv2 = nn.Conv2d(16, 32, 3,  padding=1)
+        self.conv3 = nn.Conv2d(32, 16, 3,  padding=1)
         
-        # Fully connected layer
-        self.fc1 = nn.Linear(16 * 4 * 4, 10)
+        self.img_final_size = int(in_size / 8)
         
-        # Dropout
+        self.fc1 = nn.Linear(16 * self.img_final_size * self.img_final_size, 10)
+
+        self.pool = nn.MaxPool2d(2, 2)
+
         self.dropout = nn.Dropout(p=0.2)
 
     def flatten(self, x : torch.Tensor):
     
         """Function to flatten a layer
         
-            Args: 
+            Arguments: 
                 x : torch.Tensor
 
             Returns:
@@ -72,7 +73,7 @@ class SimpleConv(nn.Module):
     def forward(self, x : torch.Tensor):
         """ Forward pass through the network which returns the softmax probabilities of the output layer
 
-        Args:
+        Arguments:
             x : torch.Tensor
                 input image to use for training
         """
diff --git a/src/utils_data.py b/src/utils_data.py
index eccfd8049c43838b4a65c67762f21a05bbe4fabd..0935676903a97fe6756798025e0432ec9b11c0d2 100644
--- a/src/utils_data.py
+++ b/src/utils_data.py
@@ -7,7 +7,7 @@ def shuffle_list(list_samples : int, seed : int) -> list:
     
     """Function to shuffle the samples list
 
-    Args:
+    Arguments:
         list_samples : A list of samples to shuffle
         seed : Randomization seed for reproducible results
     
@@ -29,13 +29,12 @@ def shuffle_list(list_samples : int, seed : int) -> list:
 
 
 
-def create_label_dict(dataset : dict, seed : int) -> dict:
+def create_label_dict(dataset : dict) -> dict:
    
     """ Create a dictionary of dataset samples 
 
-    Args:
+    Arguments:
         dataset: The name of the dataset to use ('fashion-mnist', 'mnist', or 'kmnist')
-        seed : Randomization seed for reproducible results
    
     Returns:
         A dictionary of data of the form {'x': [], 'y': []}
@@ -60,6 +59,12 @@ def create_label_dict(dataset : dict, seed : int) -> dict:
     elif dataset == 'mnist':
         mnist = torchvision.datasets.MNIST("datasets", download=True)
         (x_train, y_train) = mnist.data, mnist.targets
+        x_train = x_train.unsqueeze(1)
+
+    elif dataset == "cifar10":
+        cifar10 = torchvision.datasets.CIFAR10("datasets", download=True)
+        (x_train, y_train) = cifar10.data, cifar10.targets
+        x_train = np.transpose(x_train, (0, 3, 1, 2))
 
     elif dataset == 'kmnist':
         (x_train, y_train), _ = kmnist.load_data()
@@ -71,21 +76,18 @@ def create_label_dict(dataset : dict, seed : int) -> dict:
 
     for label in range(10):
        
-        label_indices = np.where(y_train == label)[0]
-       
+        label_indices = np.where(np.array(y_train) == label)[0]   
         label_samples_x = x_train[label_indices]
-          
         label_dict[label] = label_samples_x
         
     return label_dict
 
 
-
 def get_clients_data(num_clients : int, num_samples_by_label : int, dataset : dict, seed : int) -> dict:
     
     """Distribute a dataset evenly accross num_clients clients. Works with datasets with 10 labels
     
-    Args:
+    Arguments:
         num_clients : Number of clients of interest
             
         num_samples_by_label : Number of samples of each labels by client
@@ -96,7 +98,7 @@ def get_clients_data(num_clients : int, num_samples_by_label : int, dataset : di
     
     import numpy as np 
 
-    label_dict = create_label_dict(dataset, seed)
+    label_dict = create_label_dict(dataset)
 
     clients_dictionary = {}
     client_dataset = {}
@@ -125,7 +127,7 @@ def rotate_images(client: Client, rotation: int) -> None:
     
     """ Rotate a Client's images, used for ``concept shift on features''
     
-    Args:
+    Arguments:
         client : A Client object whose dataset images we want to rotate
         rotation : the rotation angle to apply  0 < angle < 360
     """
@@ -154,7 +156,7 @@ def data_preparation(client : Client, row_exp : dict) -> None:
     
     """Saves Dataloaders of train and test data in the Client attributes 
     
-    Args:
+    Arguments:
         client : The client object to modify
         row_exp : The current experiment's global parameters
     """
@@ -200,9 +202,9 @@ def get_dataset_heterogeneities(heterogeneity_type: str) -> dict:
     """
     Retrieves the "skew" and "ratio" attributes of a given heterogeneity type
 
-    Args:
+    Arguments:
         heterogeneity_type : The label of the heterogeneity scenario (labels-distribution-skew, concept-shift-on-labels, quantity-skew)
-    Returns
+    Returns:
         A dictionary of the form {<het>: []} where <het> is the applicable heterogeneity type 
     """
     dict_params = {}
@@ -226,7 +228,7 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
     """
     Setup function to create and personalize client's data 
 
-    Args:
+    Arguments:
         row_exp : The current experiment's global parameters
     
     Returns:
@@ -237,13 +239,16 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
     """
 
     from src.models import SimpleLinear
+    from src.utils_fed import init_server_cluster
     import torch
     
     list_clients = []
     
     torch.manual_seed(row_exp['seed'])
 
-    model_server = Server(SimpleLinear())
+    imgs_params = {'mnist': (24,1) , 'fashion-mnist': (24,1), 'kmnist': (24,1), 'cifar10': (32,3)}
+
+    model_server = Server(SimpleLinear(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
 
     dict_clients = get_clients_data(row_exp['num_clients'],
                                     row_exp['num_samples_by_label'],
@@ -255,7 +260,10 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
         list_clients.append(Client(i, dict_clients[i]))
 
     list_clients = add_clients_heterogeneity(list_clients, row_exp)
-   
+    
+    if row_exp['exp_type'] == "client":
+        init_server_cluster(model_server, list_clients, row_exp, imgs_params['dataset'])
+
     return model_server, list_clients
 
 
@@ -263,7 +271,7 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
 def add_clients_heterogeneity(list_clients: list, row_exp: dict) -> list:
     """ Utility function to apply the relevant heterogeneity classes to each client
     
-    Args:
+    Arguments:
         list_clients : List of Client Objects with specific heterogeneity_class 
         row_exp : The current experiment's global parameters
     Returns:
@@ -296,7 +304,7 @@ def apply_label_swap(list_clients : list, row_exp : dict, list_swaps : list) ->
     
     """ Utility function to apply label swaps on Client images
 
-    Args:
+    Arguments:
         list_clients : List of Client Objects with specific heterogeneity_class 
         row_exp : The current experiment's global parameters
         list_swap : List containing the labels to swap by heterogeneity class
@@ -334,7 +342,7 @@ def apply_rotation(list_clients : list, row_exp : dict) -> list:
 
     """ Utility function to apply rotation 0,90,180 and 270 to 1/4 of Clients 
 
-    Args:
+    Arguments:
         list_clients : List of Client Objects with specific heterogeneity_class 
         row_exp : The current experiment's global parameters
     
@@ -373,7 +381,7 @@ def apply_labels_skew(list_clients : list, row_exp : dict, list_skews : list, li
     
     """ Utility function to apply label skew to Clients' data 
 
-    Args:
+    Arguments:
         list_clients : List of Client Objects with specific heterogeneity_class 
         row_exp : The current experiment's global parameters
     
@@ -413,7 +421,7 @@ def apply_quantity_skew(list_clients : list, row_exp : dict, list_skews : list)
      For each element in list_skews, apply the skew to an equal subset of Clients 
 
 
-    Args:
+    Arguments:
         list_clients : List of Client Objects with specific heterogeneity_class 
         row_exp : The current experiment's global parameters
         list_skew : List of float 0 < i < 1  with quantity skews to subsample data
@@ -455,7 +463,7 @@ def apply_features_skew(list_clients : list, row_exp : dict) -> list :
     
     """ Utility function to apply features skew to Clients' data 
 
-    Args:
+    Arguments:
         list_clients : List of Client Objects with specific heterogeneity_class 
         row_exp : The current experiment's global parameters
     
@@ -500,7 +508,7 @@ def swap_labels(labels : list, client : Client, heterogeneity_class : int) -> Cl
 
     """ Utility Function for label swapping used for concept shift on labels. Sets the attribute "heterogeneity class"
     
-    Args:
+    Arguments:
         labels : Labels to swap
         client : The Client object whose data we want to apply the swap on
     Returns:
@@ -526,7 +534,7 @@ def swap_labels(labels : list, client : Client, heterogeneity_class : int) -> Cl
 def centralize_data(list_clients : list) -> Tuple[DataLoader, DataLoader]:
     """Centralize data of the federated learning setup for central model comparison
 
-    Args:
+    Arguments:
         list_clients : The list of Client Objects
 
     Returns:
@@ -564,7 +572,7 @@ def unbalancing(client : Client ,labels_list : list ,ratio_list: list) -> Client
     
     """ Downsample the dataset of a client with each elements of the labels_list will be downsampled by the corresponding ration of ratio_list
 
-    Args: 
+    Arguments: 
         client : Client whose dataset we want to downsample
         labels_list : Labels to downsample in the Client's dataset
         ratio_list : Ratios to use for downsampling the labels
@@ -572,6 +580,7 @@ def unbalancing(client : Client ,labels_list : list ,ratio_list: list) -> Client
     
     import pandas as pd
     from imblearn.datasets import make_imbalance
+    from math import prod
 
     def ratio_func(y, multiplier, minority_class):
     
@@ -584,10 +593,10 @@ def unbalancing(client : Client ,labels_list : list ,ratio_list: list) -> Client
     x_train = client.data['x']
     y_train = client.data['y']
     
-    (_, i_dim,j_dim) = x_train.shape
+    orig_shape = x_train.shape
     
      # flatten the images 
-    X_resampled = x_train.reshape(-1, i_dim * j_dim)
+    X_resampled = x_train.reshape(-1, prod(orig_shape[1:]))
     y_resampled = y_train
     
     for i in range(len(labels_list)):
@@ -599,7 +608,7 @@ def unbalancing(client : Client ,labels_list : list ,ratio_list: list) -> Client
                 sampling_strategy=ratio_func,
                 **{"multiplier": ratio_list[i], "minority_class": labels_list[i]})
 
-    client.data['x'] = X_resampled.to_numpy().reshape(-1, i_dim, j_dim)
+    client.data['x'] = X_resampled.to_numpy().reshape(-1, *orig_shape[1:])
     client.data['y'] = y_resampled
     
     return client
@@ -611,7 +620,7 @@ def dilate_images(x_train : ndarray, kernel_size : tuple = (3, 3)) -> ndarray:
     Make image 'bolder' for features distribution skew setup
     
     
-    Args:
+    Arguments:
         x_train : Input batch of images (3D array with shape (n, height, width)).
         kernel_size : Size of the structuring element/kernel for dilation.
 
@@ -643,7 +652,7 @@ def erode_images(x_train : ndarray, kernel_size : tuple =(3, 3)) -> ndarray:
     Perform erosion operation on a batch of images using a given kernel.
     Make image 'finner' for features distribution skew setup
 
-    Args:
+    Arguments:
         x_train : Input batch of images (3D array with shape (n, height, width)).
         kernel_size :  Size of the structuring element/kernel for erosion.
 
@@ -675,7 +684,7 @@ def save_results(model_server : Server, row_exp : dict ) -> None:
     """
     Saves model_server in row_exp['output'] as *.pth object
 
-    Args:
+    Arguments:
         model_server : The nn.Module to save
         row_exp :  The current experiment's global parameters
     """
diff --git a/src/utils_fed.py b/src/utils_fed.py
index 3b98f0771f1334f579b880db9b01701beaa4ea3b..7544e235ae1785d13d1e6d6e318648dbea79ecf6 100644
--- a/src/utils_fed.py
+++ b/src/utils_fed.py
@@ -7,7 +7,7 @@ def send_server_model_to_client(list_clients : list, my_server : Server) -> None
     
     """ Function to copy the Server model to client attributes in a FL protocol
 
-    Args:
+    Arguments:
         list_clients : List of Client objects on which to set the parameter `model'
         my_server : Server object with the model to copy
     """
@@ -23,7 +23,7 @@ def send_server_model_to_client(list_clients : list, my_server : Server) -> None
 def send_cluster_models_to_clients(list_clients : list , my_server : Server) -> None:
     """ Function to copy Server modelm to clients based on attribute client.cluster_id
 
-    Args: 
+    Arguments: 
         list_clients : List of Clients to update
         my_server : Server from which to fetch models
     """
@@ -43,7 +43,7 @@ def model_avg(list_clients : list) -> nn.Module:
     """  Utility function for the fed_avg function which creates a new model
          with weights set to the weighted average of 
     
-    Args:
+    Arguments:
         list_clients : List of Client whose models we want to use to perform the weighted average
 
     Returns:
@@ -66,8 +66,7 @@ def model_avg(list_clients : list) -> nn.Module:
 
             data_size = len(client.data_loader['train'].dataset)
 
-            weight = data_size / total_data_size
-            
+            weight = data_size / total_data_size            
             weighted_avg_param += client.model.state_dict()[name] * weight
 
         param.data = weighted_avg_param #TODO: make more explicit
@@ -81,22 +80,20 @@ def fedavg(my_server : Server, list_clients : list) -> None:
     The code modifies the cluster models `my_server.cluster_models[i]'
 
     
-    Args:
+    Arguments:
         my_server : Server model which contains the cluster models
 
         list_clients: List of clients, each containing a PyTorch model and a data loader.
 
     """
     if my_server.num_clusters == None:
-        # Initialize a new model
+
         my_server.model = model_avg(list_clients)
     
     else : 
          
          for cluster_id in range(my_server.num_clusters):
-          
-            # Filter clients belonging to the current cluster
-            
+                      
             cluster_clients_list = [client for client in list_clients if client.cluster_id == cluster_id]
             
             if len(cluster_clients_list)>0 :  
@@ -109,11 +106,11 @@ def model_weight_matrix(list_clients : list) -> pd.DataFrame:
    
     """ Create a weight matrix DataFrame using the weights of local federated models for use in the server-side CFL 
 
-    Args :
+    Arguments:
 
-    list_clients: List of Clients with respective models
+        list_clients: List of Clients with respective models
          
-    Returns
+    Returns:
         DataFrame with weights of each model as rows
     """
 
@@ -124,13 +121,11 @@ def model_weight_matrix(list_clients : list) -> pd.DataFrame:
     model_dict = {client.id : client.model for client in list_clients}
 
     shapes = [param.data.numpy().shape for param in next(iter(model_dict.values())).parameters()]
-
     weight_matrix_np = np.empty((len(model_dict), sum(np.prod(shape) for shape in shapes)))
 
     for idx, (_, model) in enumerate(model_dict.items()):
 
         model_weights = np.concatenate([param.data.numpy().flatten() for param in model.parameters()])
-
         weight_matrix_np[idx, :] = model_weights
 
     weight_matrix = pd.DataFrame(weight_matrix_np, columns=[f'w_{i+1}' for i in range(weight_matrix_np.shape[1])])
@@ -154,7 +149,6 @@ def k_means_cluster_id(weight_matrix : pd.DataFrame, k : int, seed : int) -> pd.
     from sklearn.cluster import KMeans
     
     kmeans = KMeans(n_clusters=k, random_state=seed)
-
     kmeans.fit(weight_matrix)
 
     weight_matrix['cluster'] = kmeans.labels_
@@ -168,7 +162,7 @@ def k_means_clustering(list_clients : list, num_clusters : int, seed : int) -> N
 
     """ Performs a k-mean clustering and sets the cluser_id attribute to clients based on the result
     
-    Args:
+    Arguments:
         list_clients : List of Clients on which to perform clustering
         num_clusters : Parameter to set the number of clusters needed
         seed : Random seed to allow reproducibility
@@ -187,12 +181,12 @@ def k_means_clustering(list_clients : list, num_clusters : int, seed : int) -> N
 
 
 
-def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict, p_expert_opinion : float = 0) -> None:
+def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict, imgs_params: dict, p_expert_opinion : float = 0) -> None:
     
     """ Function to initialize cluster membership for client-side CFL (sets param cluster id) 
     using a given distribution or completely at random. 
     
-    Args:
+    Arguments:
         my_server : Server model containing one model per cluster
 
         list_clients : List of Clients  whose model we want to initialize
@@ -217,8 +211,8 @@ def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict,
     p_rest = (1 - p_expert_opinion) / (row_exp['num_clusters'] - 1)
 
     my_server.num_clusters = row_exp['num_clusters']
-    
-    my_server.clusters_models = {cluster_id: SimpleLinear(h1=200) for cluster_id in range(row_exp['num_clusters'])}
+
+    my_server.clusters_models = {cluster_id: SimpleLinear(in_size=imgs_params[0], n_channels=imgs_params[1]) for cluster_id in range(row_exp['num_clusters'])}
     
     
     for client in list_clients:
@@ -227,7 +221,7 @@ def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict,
                         else p_expert_opinion for x in range(row_exp['num_clusters'])] 
 
         client.cluster_id = np.random.choice(range(row_exp['num_clusters']), p = probs)
-
+        
         client.model = copy.deepcopy(my_server.clusters_models[client.cluster_id])
     
     return 
@@ -237,7 +231,7 @@ def loss_calculation(model : nn.modules, train_loader : DataLoader) -> float:
 
     """ Utility function to calculate average_loss across all samples <train_loader>
 
-    Args:
+    Arguments:
 
         model : the input server model
         
@@ -274,7 +268,7 @@ def loss_calculation(model : nn.modules, train_loader : DataLoader) -> float:
 def set_client_cluster(my_server : Server, list_clients : list, row_exp : dict) -> None:
     """ Function to calculate cluster membership for client-side CFL (sets param cluster id)
     
-     Args:
+     Arguments:
         my_server : Server model containing one model per cluster
 
         list_clients : List of Clients  whose model we want to initialize
diff --git a/src/utils_logging.py b/src/utils_logging.py
index 37ef6e0bdf69a1b33f2eb66455002334d3bbe0cb..ba4cd9390b791277b251f0aceae972be96e590b0 100644
--- a/src/utils_logging.py
+++ b/src/utils_logging.py
@@ -16,7 +16,7 @@ def cprint(msg: str, lvl: str = "info") -> None:
     """
     Print message to the console at the desired logging level.
 
-    Args:
+    Arguments:
         msg (str): Message to print.
         lvl (str): Logging level between "debug", "info", "warning", "error" and "critical".
                    The default value is "info".
diff --git a/src/utils_results.py b/src/utils_results.py
index 6252519078b1ca11bd07da3008d2b7aa422db284..01e685518de02bb9c62029e203b019856e52fed4 100644
--- a/src/utils_results.py
+++ b/src/utils_results.py
@@ -1,7 +1,8 @@
 
 from pandas import DataFrame
 from pathlib import Path
-   
+from torch import tensor
+
 
 def save_histograms() -> None:
 
@@ -50,7 +51,7 @@ def append_empty_clusters(list_clusters : list) -> list:
     """
     Utility function for ``get_clusters'' to handle the situation where some clusters are empty by appending the clusters ID
     
-    Args:
+    Arguments:
         list_clusters: List of clusters with clients
 
     Returns:
@@ -71,8 +72,10 @@ def append_empty_clusters(list_clusters : list) -> list:
 
 
 
-def get_z_nclients(df_results, x_het, y_clust, labels_heterogeneity):
+def get_z_nclients(df_results : dict, x_het : list, y_clust : list, labels_heterogeneity : list) -> list:
     
+    """ Returns the number of clients associated with a given heterogeneity class for each cluster"""
+
     z_nclients = [0]* len(x_het)
 
     for i in range(len(z_nclients)):
@@ -84,11 +87,22 @@ def get_z_nclients(df_results, x_het, y_clust, labels_heterogeneity):
 
 
 
+def plot_img(img : tensor) -> None:
+
+    """Utility function to plot an image of any shape"""
+
+    from torchvision import transforms
+    import matplotlib.pyplot as plt
+
+    plt.imshow(transforms.ToPILImage()(img))
+
+
+
 def plot_histogram_clusters(df_results: DataFrame, title : str) -> None:
     
     """ Function to create 3D Histograms of clients to cluster assignments showing client's heterogeneity class 
 
-    Args:
+    Arguments:
         
         df_results : DataFrame containing all parameters from the resulting csv files
         
diff --git a/src/utils_training.py b/src/utils_training.py
index 6efdf63dbbbf7fe4d8c81a26ed9b6405d207e5f1..91e971d7e10369ceeb1f45de38db0e832430a447 100644
--- a/src/utils_training.py
+++ b/src/utils_training.py
@@ -16,13 +16,11 @@ def run_cfl_server_side(model_server : Server, list_clients : list, row_exp : di
     """ Driver function for server-side cluster FL algorithm. The algorithm personalize training by clusters obtained
     from model weights (k-means).
     
-     Args:
-        
-        model_server : The nn.Module to save
-        
-        list_clients : A list of Client Objects used as nodes in the FL protocol
-        row_exp : The current experiment's global parameters
+    Arguments:
 
+        main_model : Type of Server model needed    
+        list_clients : A list of Client Objects used as nodes in the FL protocol  
+        row_exp : The current experiment's global parameters
     """
     from src.utils_fed import k_means_clustering
     import copy
@@ -40,40 +38,36 @@ def run_cfl_server_side(model_server : Server, list_clients : list, row_exp : di
     
     model_server = train_federated(model_server, list_clients, row_exp, use_cluster_models = True)
 
-    list_clients = add_clients_accuracies(model_server, list_clients)
+    for client in list_clients :
+
+        acc = test_model(model_server.clusters_models[client.cluster_id], client.data_loader['test'])
+        
+        setattr(client, 'accuracy', acc)
+
 
     df_results = pd.DataFrame.from_records([c.to_dict() for c in list_clients])
     
     return df_results 
 
 
-
-
 def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : dict, init_cluster=True) -> pd.DataFrame:
 
     """ Driver function for client-side cluster FL algorithm. The algorithm personalize training by clusters obtained
     from model weights (k-means).
     
-     Args:
-        
-        model_server : The nn.Module to save
 
-        list_clients : A list of Client Objects used as nodes in the FL protocol
+    Arguments:
 
+        main_model : Type of Server model needed    
+        list_clients : A list of Client Objects used as nodes in the FL protocol  
         row_exp : The current experiment's global parameters
-
-        init_clusters : boolean indicating whether cluster assignement is done before initial training
-
+        init_cluster : A boolean indicating whether to initialize cluster prior to training
     """
 
-    from src.utils_fed import init_server_cluster, set_client_cluster, fedavg
+    from src.utils_fed import  set_client_cluster, fedavg
     import torch
 
     torch.manual_seed(row_exp['seed'])
-
-    if init_cluster == True : 
-        
-        init_server_cluster(model_server, list_clients, row_exp, p_expert_opinion=0.0)
     
     for _ in range(row_exp['federated_rounds']):
 
@@ -85,22 +79,25 @@ def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : di
         
         set_client_cluster(model_server, list_clients, row_exp)
 
-    list_clients = add_clients_accuracies(model_server, list_clients)
+    for client in list_clients :
+
+        acc = test_model(model_server.clusters_models[client.cluster_id], client.data_loader['test'])
+        
+        setattr(client, 'accuracy', acc)
 
     df_results = pd.DataFrame.from_records([c.to_dict() for c in list_clients])
     
     return df_results
-    
+
 
 def run_benchmark(main_model : nn.Module, list_clients : list, row_exp : dict) -> pd.DataFrame:
 
     """ Benchmark function to calculate baseline FL results and ``optimal'' personalization results if clusters are known in advance
 
-    Args:
+    Arguments:
+
         main_model : Type of Server model needed    
-    
         list_clients : A list of Client Objects used as nodes in the FL protocol  
-
         row_exp : The current experiment's global parameters
     """
 
@@ -129,7 +126,11 @@ def run_benchmark(main_model : nn.Module, list_clients : list, row_exp : dict) -
 
                 model_trained, _ = train_central(curr_model, train_loader, row_exp) 
 
-                test_benchmark(model_trained, list_clients_filtered, test_loader, row_exp)
+                global_acc = test_model(model_trained, test_loader) 
+                     
+                for client in list_clients_filtered : 
+        
+                    setattr(client, 'accuracy', global_acc)
     
         case 'global-federated':
                 
@@ -139,55 +140,27 @@ def run_benchmark(main_model : nn.Module, list_clients : list, row_exp : dict) -
         
             _, test_loader = centralize_data(list_clients)
 
-            test_benchmark(model_trained.model, list_clients, test_loader, row_exp)
-
-    df_results = pd.DataFrame.from_records([c.to_dict() for c in list_clients])
-    
-    return df_results
-
-
-def test_benchmark(model_trained : nn.Module, list_clients : list, test_loader : DataLoader, row_exp : dict):
-
-    """ Tests <model_trained> on test_loader (global) dataset and sets the attribute accuracy on each Client 
-
-        Args:
-                
-            list_clients : A list of Client Objects used as nodes in the FL protocol  
-
-            row_exp : The current experiment's global parameters
-
-            main_model : Type of Server model needed
-
-            training_type : a value frmo ['global-federated', 'pers-centralized'] 
-
-    """       
-         
-    from src.utils_training import test_model
-    
-    global_acc = test_model(model_trained, test_loader) 
+            global_acc = test_model(model_trained.model, test_loader) 
                      
-    for client in list_clients : 
+            for client in list_clients : 
         
-        #client_acc = test_model(model_trained, client.data_loader['test'])*100
+                setattr(client, 'accuracy', global_acc)
 
-        setattr(client, 'accuracy', global_acc)
+    df_results = pd.DataFrame.from_records([c.to_dict() for c in list_clients])
     
-    return global_acc
+    return df_results
 
 
 def train_federated(main_model, list_clients, row_exp, use_cluster_models = False):
     
     """Controler function to launch federated learning
 
-    Args:
-        
-        main_model : Server model used in our experiment
-        
-        list_clients : A list of Client Objects used as nodes in the FL protocol  
-
-        row_exp : The current experiment's global parameters
+    Arguments:
 
-        use_cluster_models : Boolean to determine whether to use personalization by clustering
+        main_model: Server model used in our experiment
+        list_clients: A list of Client Objects used as nodes in the FL protocol  
+        row_exp: The current experiment's global parameters
+        use_cluster_models: Boolean to determine whether to use personalization by clustering
     """
 
     from src.utils_fed import send_server_model_to_client, send_cluster_models_to_clients, fedavg
@@ -220,12 +193,10 @@ def train_central(main_model, train_loader, row_exp):
 
     """ Main training function for centralized learning
     
-    Args:
+    Arguments:
 
         main_model : Server model used in our experiment
-        
         train_loader : DataLoader with the dataset to use for training
-
         row_exp : The current experiment's global parameters
 
     """
@@ -271,10 +242,8 @@ def test_model(model : nn.Module, test_loader : DataLoader) -> float:
 
     """ Calcualtes model accuracy (percentage) on the <test_loader> Dataset
     
-    Args:
-
+    Arguments:
         model : the input server model
-        
         test_loader : DataLoader with the dataset to use for testing
     """
     
@@ -307,24 +276,4 @@ def test_model(model : nn.Module, test_loader : DataLoader) -> float:
 
     accuracy = (correct / total) * 100
 
-    return accuracy
-
-
-def add_clients_accuracies(model_server : nn.Module, list_clients : list) -> list:
-
-    """
-    Evaluates the cluster's models saved in <model_server> on the relevant list of clients and sets the attribute accuracy.
-
-    Args:
-        model_server : Server object which contains the cluster models
-
-        list_clients : list of Client objects which belong to the different clusters<
-    """
-
-    for client in list_clients :
-
-        acc = test_model(model_server.clusters_models[client.cluster_id], client.data_loader['test'])
-        
-        setattr(client, 'accuracy', acc)
-
-    return list_clients
+    return accuracy
\ No newline at end of file