diff --git a/src/fedclass.py b/src/fedclass.py index d859bdee3793b46ea288ca35031e264e973ec053..4a1a3476ccbb86309ea64fdb796dfe0b3ff07c55 100644 --- a/src/fedclass.py +++ b/src/fedclass.py @@ -12,7 +12,7 @@ class Client: """Initialize the Client object - Args: + Arguments: id : int unique client identifier data : dict @@ -70,7 +70,7 @@ class Server: def __init__(self,model,num_clusters: int=None): """Initialize a Server object with an empty dictionary of cluster_models - Args: + Arguments: model: nn.Module The nn learing model the server is associated with num_clusters: int diff --git a/src/metrics.py b/src/metrics.py index eadbd06214a1078f3ae001f7a5324cc0b9682b0f..651b3240611c4fcec1cd1125213484d916a3ead3 100644 --- a/src/metrics.py +++ b/src/metrics.py @@ -3,7 +3,7 @@ def calc_global_metrics(labels_true: list, labels_pred: list) -> dict: """ Calculate global metrics based on model weights - Args: + Arguments: labels_true : list list of ground truth labels labels_pred : list diff --git a/src/models.py b/src/models.py index 54fe9aa3ba461a4369770f1148aea6831a4614d6..54d82f132f36dd65d40ae40b2975ee2ca867ed0b 100644 --- a/src/models.py +++ b/src/models.py @@ -3,13 +3,13 @@ import torch.nn as nn import torch.nn.functional as F -class SimpleLinear(nn.Module): +class SimpleLinear2(nn.Module): """ Fully connected neural network with a single hidden layer of default size 200 and ReLU activations""" def __init__(self, h1=200): """ Initialization function - Args: + Arguments: h1: int Desired size of the hidden layer """ @@ -21,11 +21,12 @@ class SimpleLinear(nn.Module): """ Forward pass function through the network - Args: + Arguments: x : torch.Tensor input image of size 28 x 28 - Returns: log_softmax probabilities of the output layer + Returns: + log_softmax probabilities of the output layer """ x = x.view(-1, 28 * 28) @@ -34,33 +35,33 @@ class SimpleLinear(nn.Module): return F.log_softmax(x, dim=1) -class SimpleConv(nn.Module): +class SimpleLinear(nn.Module): """ Convolutional neural network with 3 convolutional layers and one fully connected layer """ - def __init__(self): + def __init__(self, in_size, n_channels): """ Initialization function """ - super(SimpleConv, self).__init__() - # convolutional layer - self.conv1 = nn.Conv2d(3, 16, 3, padding=1) - self.conv2 = nn.Conv2d(16, 32, 3, padding = 1) - self.conv3 = nn.Conv2d(32, 16, 3, padding = 1) - # max pooling layer - self.pool = nn.MaxPool2d(2, 2) + super(SimpleLinear, self).__init__() + + self.conv1 = nn.Conv2d(n_channels, 16, 3, padding=1) + self.conv2 = nn.Conv2d(16, 32, 3, padding=1) + self.conv3 = nn.Conv2d(32, 16, 3, padding=1) - # Fully connected layer - self.fc1 = nn.Linear(16 * 4 * 4, 10) + self.img_final_size = int(in_size / 8) - # Dropout + self.fc1 = nn.Linear(16 * self.img_final_size * self.img_final_size, 10) + + self.pool = nn.MaxPool2d(2, 2) + self.dropout = nn.Dropout(p=0.2) def flatten(self, x : torch.Tensor): """Function to flatten a layer - Args: + Arguments: x : torch.Tensor Returns: @@ -72,7 +73,7 @@ class SimpleConv(nn.Module): def forward(self, x : torch.Tensor): """ Forward pass through the network which returns the softmax probabilities of the output layer - Args: + Arguments: x : torch.Tensor input image to use for training """ diff --git a/src/utils_data.py b/src/utils_data.py index eccfd8049c43838b4a65c67762f21a05bbe4fabd..0935676903a97fe6756798025e0432ec9b11c0d2 100644 --- a/src/utils_data.py +++ b/src/utils_data.py @@ -7,7 +7,7 @@ def shuffle_list(list_samples : int, seed : int) -> list: """Function to shuffle the samples list - Args: + Arguments: list_samples : A list of samples to shuffle seed : Randomization seed for reproducible results @@ -29,13 +29,12 @@ def shuffle_list(list_samples : int, seed : int) -> list: -def create_label_dict(dataset : dict, seed : int) -> dict: +def create_label_dict(dataset : dict) -> dict: """ Create a dictionary of dataset samples - Args: + Arguments: dataset: The name of the dataset to use ('fashion-mnist', 'mnist', or 'kmnist') - seed : Randomization seed for reproducible results Returns: A dictionary of data of the form {'x': [], 'y': []} @@ -60,6 +59,12 @@ def create_label_dict(dataset : dict, seed : int) -> dict: elif dataset == 'mnist': mnist = torchvision.datasets.MNIST("datasets", download=True) (x_train, y_train) = mnist.data, mnist.targets + x_train = x_train.unsqueeze(1) + + elif dataset == "cifar10": + cifar10 = torchvision.datasets.CIFAR10("datasets", download=True) + (x_train, y_train) = cifar10.data, cifar10.targets + x_train = np.transpose(x_train, (0, 3, 1, 2)) elif dataset == 'kmnist': (x_train, y_train), _ = kmnist.load_data() @@ -71,21 +76,18 @@ def create_label_dict(dataset : dict, seed : int) -> dict: for label in range(10): - label_indices = np.where(y_train == label)[0] - + label_indices = np.where(np.array(y_train) == label)[0] label_samples_x = x_train[label_indices] - label_dict[label] = label_samples_x return label_dict - def get_clients_data(num_clients : int, num_samples_by_label : int, dataset : dict, seed : int) -> dict: """Distribute a dataset evenly accross num_clients clients. Works with datasets with 10 labels - Args: + Arguments: num_clients : Number of clients of interest num_samples_by_label : Number of samples of each labels by client @@ -96,7 +98,7 @@ def get_clients_data(num_clients : int, num_samples_by_label : int, dataset : di import numpy as np - label_dict = create_label_dict(dataset, seed) + label_dict = create_label_dict(dataset) clients_dictionary = {} client_dataset = {} @@ -125,7 +127,7 @@ def rotate_images(client: Client, rotation: int) -> None: """ Rotate a Client's images, used for ``concept shift on features'' - Args: + Arguments: client : A Client object whose dataset images we want to rotate rotation : the rotation angle to apply 0 < angle < 360 """ @@ -154,7 +156,7 @@ def data_preparation(client : Client, row_exp : dict) -> None: """Saves Dataloaders of train and test data in the Client attributes - Args: + Arguments: client : The client object to modify row_exp : The current experiment's global parameters """ @@ -200,9 +202,9 @@ def get_dataset_heterogeneities(heterogeneity_type: str) -> dict: """ Retrieves the "skew" and "ratio" attributes of a given heterogeneity type - Args: + Arguments: heterogeneity_type : The label of the heterogeneity scenario (labels-distribution-skew, concept-shift-on-labels, quantity-skew) - Returns + Returns: A dictionary of the form {<het>: []} where <het> is the applicable heterogeneity type """ dict_params = {} @@ -226,7 +228,7 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]: """ Setup function to create and personalize client's data - Args: + Arguments: row_exp : The current experiment's global parameters Returns: @@ -237,13 +239,16 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]: """ from src.models import SimpleLinear + from src.utils_fed import init_server_cluster import torch list_clients = [] torch.manual_seed(row_exp['seed']) - model_server = Server(SimpleLinear()) + imgs_params = {'mnist': (24,1) , 'fashion-mnist': (24,1), 'kmnist': (24,1), 'cifar10': (32,3)} + + model_server = Server(SimpleLinear(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1])) dict_clients = get_clients_data(row_exp['num_clients'], row_exp['num_samples_by_label'], @@ -255,7 +260,10 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]: list_clients.append(Client(i, dict_clients[i])) list_clients = add_clients_heterogeneity(list_clients, row_exp) - + + if row_exp['exp_type'] == "client": + init_server_cluster(model_server, list_clients, row_exp, imgs_params['dataset']) + return model_server, list_clients @@ -263,7 +271,7 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]: def add_clients_heterogeneity(list_clients: list, row_exp: dict) -> list: """ Utility function to apply the relevant heterogeneity classes to each client - Args: + Arguments: list_clients : List of Client Objects with specific heterogeneity_class row_exp : The current experiment's global parameters Returns: @@ -296,7 +304,7 @@ def apply_label_swap(list_clients : list, row_exp : dict, list_swaps : list) -> """ Utility function to apply label swaps on Client images - Args: + Arguments: list_clients : List of Client Objects with specific heterogeneity_class row_exp : The current experiment's global parameters list_swap : List containing the labels to swap by heterogeneity class @@ -334,7 +342,7 @@ def apply_rotation(list_clients : list, row_exp : dict) -> list: """ Utility function to apply rotation 0,90,180 and 270 to 1/4 of Clients - Args: + Arguments: list_clients : List of Client Objects with specific heterogeneity_class row_exp : The current experiment's global parameters @@ -373,7 +381,7 @@ def apply_labels_skew(list_clients : list, row_exp : dict, list_skews : list, li """ Utility function to apply label skew to Clients' data - Args: + Arguments: list_clients : List of Client Objects with specific heterogeneity_class row_exp : The current experiment's global parameters @@ -413,7 +421,7 @@ def apply_quantity_skew(list_clients : list, row_exp : dict, list_skews : list) For each element in list_skews, apply the skew to an equal subset of Clients - Args: + Arguments: list_clients : List of Client Objects with specific heterogeneity_class row_exp : The current experiment's global parameters list_skew : List of float 0 < i < 1 with quantity skews to subsample data @@ -455,7 +463,7 @@ def apply_features_skew(list_clients : list, row_exp : dict) -> list : """ Utility function to apply features skew to Clients' data - Args: + Arguments: list_clients : List of Client Objects with specific heterogeneity_class row_exp : The current experiment's global parameters @@ -500,7 +508,7 @@ def swap_labels(labels : list, client : Client, heterogeneity_class : int) -> Cl """ Utility Function for label swapping used for concept shift on labels. Sets the attribute "heterogeneity class" - Args: + Arguments: labels : Labels to swap client : The Client object whose data we want to apply the swap on Returns: @@ -526,7 +534,7 @@ def swap_labels(labels : list, client : Client, heterogeneity_class : int) -> Cl def centralize_data(list_clients : list) -> Tuple[DataLoader, DataLoader]: """Centralize data of the federated learning setup for central model comparison - Args: + Arguments: list_clients : The list of Client Objects Returns: @@ -564,7 +572,7 @@ def unbalancing(client : Client ,labels_list : list ,ratio_list: list) -> Client """ Downsample the dataset of a client with each elements of the labels_list will be downsampled by the corresponding ration of ratio_list - Args: + Arguments: client : Client whose dataset we want to downsample labels_list : Labels to downsample in the Client's dataset ratio_list : Ratios to use for downsampling the labels @@ -572,6 +580,7 @@ def unbalancing(client : Client ,labels_list : list ,ratio_list: list) -> Client import pandas as pd from imblearn.datasets import make_imbalance + from math import prod def ratio_func(y, multiplier, minority_class): @@ -584,10 +593,10 @@ def unbalancing(client : Client ,labels_list : list ,ratio_list: list) -> Client x_train = client.data['x'] y_train = client.data['y'] - (_, i_dim,j_dim) = x_train.shape + orig_shape = x_train.shape # flatten the images - X_resampled = x_train.reshape(-1, i_dim * j_dim) + X_resampled = x_train.reshape(-1, prod(orig_shape[1:])) y_resampled = y_train for i in range(len(labels_list)): @@ -599,7 +608,7 @@ def unbalancing(client : Client ,labels_list : list ,ratio_list: list) -> Client sampling_strategy=ratio_func, **{"multiplier": ratio_list[i], "minority_class": labels_list[i]}) - client.data['x'] = X_resampled.to_numpy().reshape(-1, i_dim, j_dim) + client.data['x'] = X_resampled.to_numpy().reshape(-1, *orig_shape[1:]) client.data['y'] = y_resampled return client @@ -611,7 +620,7 @@ def dilate_images(x_train : ndarray, kernel_size : tuple = (3, 3)) -> ndarray: Make image 'bolder' for features distribution skew setup - Args: + Arguments: x_train : Input batch of images (3D array with shape (n, height, width)). kernel_size : Size of the structuring element/kernel for dilation. @@ -643,7 +652,7 @@ def erode_images(x_train : ndarray, kernel_size : tuple =(3, 3)) -> ndarray: Perform erosion operation on a batch of images using a given kernel. Make image 'finner' for features distribution skew setup - Args: + Arguments: x_train : Input batch of images (3D array with shape (n, height, width)). kernel_size : Size of the structuring element/kernel for erosion. @@ -675,7 +684,7 @@ def save_results(model_server : Server, row_exp : dict ) -> None: """ Saves model_server in row_exp['output'] as *.pth object - Args: + Arguments: model_server : The nn.Module to save row_exp : The current experiment's global parameters """ diff --git a/src/utils_fed.py b/src/utils_fed.py index 3b98f0771f1334f579b880db9b01701beaa4ea3b..7544e235ae1785d13d1e6d6e318648dbea79ecf6 100644 --- a/src/utils_fed.py +++ b/src/utils_fed.py @@ -7,7 +7,7 @@ def send_server_model_to_client(list_clients : list, my_server : Server) -> None """ Function to copy the Server model to client attributes in a FL protocol - Args: + Arguments: list_clients : List of Client objects on which to set the parameter `model' my_server : Server object with the model to copy """ @@ -23,7 +23,7 @@ def send_server_model_to_client(list_clients : list, my_server : Server) -> None def send_cluster_models_to_clients(list_clients : list , my_server : Server) -> None: """ Function to copy Server modelm to clients based on attribute client.cluster_id - Args: + Arguments: list_clients : List of Clients to update my_server : Server from which to fetch models """ @@ -43,7 +43,7 @@ def model_avg(list_clients : list) -> nn.Module: """ Utility function for the fed_avg function which creates a new model with weights set to the weighted average of - Args: + Arguments: list_clients : List of Client whose models we want to use to perform the weighted average Returns: @@ -66,8 +66,7 @@ def model_avg(list_clients : list) -> nn.Module: data_size = len(client.data_loader['train'].dataset) - weight = data_size / total_data_size - + weight = data_size / total_data_size weighted_avg_param += client.model.state_dict()[name] * weight param.data = weighted_avg_param #TODO: make more explicit @@ -81,22 +80,20 @@ def fedavg(my_server : Server, list_clients : list) -> None: The code modifies the cluster models `my_server.cluster_models[i]' - Args: + Arguments: my_server : Server model which contains the cluster models list_clients: List of clients, each containing a PyTorch model and a data loader. """ if my_server.num_clusters == None: - # Initialize a new model + my_server.model = model_avg(list_clients) else : for cluster_id in range(my_server.num_clusters): - - # Filter clients belonging to the current cluster - + cluster_clients_list = [client for client in list_clients if client.cluster_id == cluster_id] if len(cluster_clients_list)>0 : @@ -109,11 +106,11 @@ def model_weight_matrix(list_clients : list) -> pd.DataFrame: """ Create a weight matrix DataFrame using the weights of local federated models for use in the server-side CFL - Args : + Arguments: - list_clients: List of Clients with respective models + list_clients: List of Clients with respective models - Returns + Returns: DataFrame with weights of each model as rows """ @@ -124,13 +121,11 @@ def model_weight_matrix(list_clients : list) -> pd.DataFrame: model_dict = {client.id : client.model for client in list_clients} shapes = [param.data.numpy().shape for param in next(iter(model_dict.values())).parameters()] - weight_matrix_np = np.empty((len(model_dict), sum(np.prod(shape) for shape in shapes))) for idx, (_, model) in enumerate(model_dict.items()): model_weights = np.concatenate([param.data.numpy().flatten() for param in model.parameters()]) - weight_matrix_np[idx, :] = model_weights weight_matrix = pd.DataFrame(weight_matrix_np, columns=[f'w_{i+1}' for i in range(weight_matrix_np.shape[1])]) @@ -154,7 +149,6 @@ def k_means_cluster_id(weight_matrix : pd.DataFrame, k : int, seed : int) -> pd. from sklearn.cluster import KMeans kmeans = KMeans(n_clusters=k, random_state=seed) - kmeans.fit(weight_matrix) weight_matrix['cluster'] = kmeans.labels_ @@ -168,7 +162,7 @@ def k_means_clustering(list_clients : list, num_clusters : int, seed : int) -> N """ Performs a k-mean clustering and sets the cluser_id attribute to clients based on the result - Args: + Arguments: list_clients : List of Clients on which to perform clustering num_clusters : Parameter to set the number of clusters needed seed : Random seed to allow reproducibility @@ -187,12 +181,12 @@ def k_means_clustering(list_clients : list, num_clusters : int, seed : int) -> N -def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict, p_expert_opinion : float = 0) -> None: +def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict, imgs_params: dict, p_expert_opinion : float = 0) -> None: """ Function to initialize cluster membership for client-side CFL (sets param cluster id) using a given distribution or completely at random. - Args: + Arguments: my_server : Server model containing one model per cluster list_clients : List of Clients whose model we want to initialize @@ -217,8 +211,8 @@ def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict, p_rest = (1 - p_expert_opinion) / (row_exp['num_clusters'] - 1) my_server.num_clusters = row_exp['num_clusters'] - - my_server.clusters_models = {cluster_id: SimpleLinear(h1=200) for cluster_id in range(row_exp['num_clusters'])} + + my_server.clusters_models = {cluster_id: SimpleLinear(in_size=imgs_params[0], n_channels=imgs_params[1]) for cluster_id in range(row_exp['num_clusters'])} for client in list_clients: @@ -227,7 +221,7 @@ def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict, else p_expert_opinion for x in range(row_exp['num_clusters'])] client.cluster_id = np.random.choice(range(row_exp['num_clusters']), p = probs) - + client.model = copy.deepcopy(my_server.clusters_models[client.cluster_id]) return @@ -237,7 +231,7 @@ def loss_calculation(model : nn.modules, train_loader : DataLoader) -> float: """ Utility function to calculate average_loss across all samples <train_loader> - Args: + Arguments: model : the input server model @@ -274,7 +268,7 @@ def loss_calculation(model : nn.modules, train_loader : DataLoader) -> float: def set_client_cluster(my_server : Server, list_clients : list, row_exp : dict) -> None: """ Function to calculate cluster membership for client-side CFL (sets param cluster id) - Args: + Arguments: my_server : Server model containing one model per cluster list_clients : List of Clients whose model we want to initialize diff --git a/src/utils_logging.py b/src/utils_logging.py index 37ef6e0bdf69a1b33f2eb66455002334d3bbe0cb..ba4cd9390b791277b251f0aceae972be96e590b0 100644 --- a/src/utils_logging.py +++ b/src/utils_logging.py @@ -16,7 +16,7 @@ def cprint(msg: str, lvl: str = "info") -> None: """ Print message to the console at the desired logging level. - Args: + Arguments: msg (str): Message to print. lvl (str): Logging level between "debug", "info", "warning", "error" and "critical". The default value is "info". diff --git a/src/utils_results.py b/src/utils_results.py index 6252519078b1ca11bd07da3008d2b7aa422db284..01e685518de02bb9c62029e203b019856e52fed4 100644 --- a/src/utils_results.py +++ b/src/utils_results.py @@ -1,7 +1,8 @@ from pandas import DataFrame from pathlib import Path - +from torch import tensor + def save_histograms() -> None: @@ -50,7 +51,7 @@ def append_empty_clusters(list_clusters : list) -> list: """ Utility function for ``get_clusters'' to handle the situation where some clusters are empty by appending the clusters ID - Args: + Arguments: list_clusters: List of clusters with clients Returns: @@ -71,8 +72,10 @@ def append_empty_clusters(list_clusters : list) -> list: -def get_z_nclients(df_results, x_het, y_clust, labels_heterogeneity): +def get_z_nclients(df_results : dict, x_het : list, y_clust : list, labels_heterogeneity : list) -> list: + """ Returns the number of clients associated with a given heterogeneity class for each cluster""" + z_nclients = [0]* len(x_het) for i in range(len(z_nclients)): @@ -84,11 +87,22 @@ def get_z_nclients(df_results, x_het, y_clust, labels_heterogeneity): +def plot_img(img : tensor) -> None: + + """Utility function to plot an image of any shape""" + + from torchvision import transforms + import matplotlib.pyplot as plt + + plt.imshow(transforms.ToPILImage()(img)) + + + def plot_histogram_clusters(df_results: DataFrame, title : str) -> None: """ Function to create 3D Histograms of clients to cluster assignments showing client's heterogeneity class - Args: + Arguments: df_results : DataFrame containing all parameters from the resulting csv files diff --git a/src/utils_training.py b/src/utils_training.py index 6efdf63dbbbf7fe4d8c81a26ed9b6405d207e5f1..91e971d7e10369ceeb1f45de38db0e832430a447 100644 --- a/src/utils_training.py +++ b/src/utils_training.py @@ -16,13 +16,11 @@ def run_cfl_server_side(model_server : Server, list_clients : list, row_exp : di """ Driver function for server-side cluster FL algorithm. The algorithm personalize training by clusters obtained from model weights (k-means). - Args: - - model_server : The nn.Module to save - - list_clients : A list of Client Objects used as nodes in the FL protocol - row_exp : The current experiment's global parameters + Arguments: + main_model : Type of Server model needed + list_clients : A list of Client Objects used as nodes in the FL protocol + row_exp : The current experiment's global parameters """ from src.utils_fed import k_means_clustering import copy @@ -40,40 +38,36 @@ def run_cfl_server_side(model_server : Server, list_clients : list, row_exp : di model_server = train_federated(model_server, list_clients, row_exp, use_cluster_models = True) - list_clients = add_clients_accuracies(model_server, list_clients) + for client in list_clients : + + acc = test_model(model_server.clusters_models[client.cluster_id], client.data_loader['test']) + + setattr(client, 'accuracy', acc) + df_results = pd.DataFrame.from_records([c.to_dict() for c in list_clients]) return df_results - - def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : dict, init_cluster=True) -> pd.DataFrame: """ Driver function for client-side cluster FL algorithm. The algorithm personalize training by clusters obtained from model weights (k-means). - Args: - - model_server : The nn.Module to save - list_clients : A list of Client Objects used as nodes in the FL protocol + Arguments: + main_model : Type of Server model needed + list_clients : A list of Client Objects used as nodes in the FL protocol row_exp : The current experiment's global parameters - - init_clusters : boolean indicating whether cluster assignement is done before initial training - + init_cluster : A boolean indicating whether to initialize cluster prior to training """ - from src.utils_fed import init_server_cluster, set_client_cluster, fedavg + from src.utils_fed import set_client_cluster, fedavg import torch torch.manual_seed(row_exp['seed']) - - if init_cluster == True : - - init_server_cluster(model_server, list_clients, row_exp, p_expert_opinion=0.0) for _ in range(row_exp['federated_rounds']): @@ -85,22 +79,25 @@ def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : di set_client_cluster(model_server, list_clients, row_exp) - list_clients = add_clients_accuracies(model_server, list_clients) + for client in list_clients : + + acc = test_model(model_server.clusters_models[client.cluster_id], client.data_loader['test']) + + setattr(client, 'accuracy', acc) df_results = pd.DataFrame.from_records([c.to_dict() for c in list_clients]) return df_results - + def run_benchmark(main_model : nn.Module, list_clients : list, row_exp : dict) -> pd.DataFrame: """ Benchmark function to calculate baseline FL results and ``optimal'' personalization results if clusters are known in advance - Args: + Arguments: + main_model : Type of Server model needed - list_clients : A list of Client Objects used as nodes in the FL protocol - row_exp : The current experiment's global parameters """ @@ -129,7 +126,11 @@ def run_benchmark(main_model : nn.Module, list_clients : list, row_exp : dict) - model_trained, _ = train_central(curr_model, train_loader, row_exp) - test_benchmark(model_trained, list_clients_filtered, test_loader, row_exp) + global_acc = test_model(model_trained, test_loader) + + for client in list_clients_filtered : + + setattr(client, 'accuracy', global_acc) case 'global-federated': @@ -139,55 +140,27 @@ def run_benchmark(main_model : nn.Module, list_clients : list, row_exp : dict) - _, test_loader = centralize_data(list_clients) - test_benchmark(model_trained.model, list_clients, test_loader, row_exp) - - df_results = pd.DataFrame.from_records([c.to_dict() for c in list_clients]) - - return df_results - - -def test_benchmark(model_trained : nn.Module, list_clients : list, test_loader : DataLoader, row_exp : dict): - - """ Tests <model_trained> on test_loader (global) dataset and sets the attribute accuracy on each Client - - Args: - - list_clients : A list of Client Objects used as nodes in the FL protocol - - row_exp : The current experiment's global parameters - - main_model : Type of Server model needed - - training_type : a value frmo ['global-federated', 'pers-centralized'] - - """ - - from src.utils_training import test_model - - global_acc = test_model(model_trained, test_loader) + global_acc = test_model(model_trained.model, test_loader) - for client in list_clients : + for client in list_clients : - #client_acc = test_model(model_trained, client.data_loader['test'])*100 + setattr(client, 'accuracy', global_acc) - setattr(client, 'accuracy', global_acc) + df_results = pd.DataFrame.from_records([c.to_dict() for c in list_clients]) - return global_acc + return df_results def train_federated(main_model, list_clients, row_exp, use_cluster_models = False): """Controler function to launch federated learning - Args: - - main_model : Server model used in our experiment - - list_clients : A list of Client Objects used as nodes in the FL protocol - - row_exp : The current experiment's global parameters + Arguments: - use_cluster_models : Boolean to determine whether to use personalization by clustering + main_model: Server model used in our experiment + list_clients: A list of Client Objects used as nodes in the FL protocol + row_exp: The current experiment's global parameters + use_cluster_models: Boolean to determine whether to use personalization by clustering """ from src.utils_fed import send_server_model_to_client, send_cluster_models_to_clients, fedavg @@ -220,12 +193,10 @@ def train_central(main_model, train_loader, row_exp): """ Main training function for centralized learning - Args: + Arguments: main_model : Server model used in our experiment - train_loader : DataLoader with the dataset to use for training - row_exp : The current experiment's global parameters """ @@ -271,10 +242,8 @@ def test_model(model : nn.Module, test_loader : DataLoader) -> float: """ Calcualtes model accuracy (percentage) on the <test_loader> Dataset - Args: - + Arguments: model : the input server model - test_loader : DataLoader with the dataset to use for testing """ @@ -307,24 +276,4 @@ def test_model(model : nn.Module, test_loader : DataLoader) -> float: accuracy = (correct / total) * 100 - return accuracy - - -def add_clients_accuracies(model_server : nn.Module, list_clients : list) -> list: - - """ - Evaluates the cluster's models saved in <model_server> on the relevant list of clients and sets the attribute accuracy. - - Args: - model_server : Server object which contains the cluster models - - list_clients : list of Client objects which belong to the different clusters< - """ - - for client in list_clients : - - acc = test_model(model_server.clusters_models[client.cluster_id], client.data_loader['test']) - - setattr(client, 'accuracy', acc) - - return list_clients + return accuracy \ No newline at end of file