diff --git a/.gitignore b/.gitignore
index 0b25033a1b960e90cbfb7cb05570ffbc03c31181..4ccaa027b330ceccd99b99eeda8dff2b001392bd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,16 +1,16 @@
-*.py~
-results/*
-.vscode/*
-src/__pycache__/*
-info.log
-launch.json
-backup_results/*
-*.sh
-tests/__pycache__/*
-datasets/*
-pub/*
-data/*
-*.tgz
-*.pyc
-
+*.py~
+results/*
+.vscode/*
+src/__pycache__/*
+info.log
+launch.json
+backup_results/*
+*.sh
+tests/__pycache__/*
+datasets/*
+pub/*
+data/*
+*.tgz
+*.pyc
+
 src/__pycache__/*
\ No newline at end of file
diff --git a/README.md b/README.md
index e3a2736777e80424ef36e463e9eadefb0d47c8be..1d1d3d24404a524e12eeeb34ae522e19f7a7caea 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,12 @@
-#### Code for the paper: *Comparative Evaluation of Clustered Federated Learning Methods*
-
-##### Submited to 'The 2nd IEEE International Conference on Federated Learning Technologies and Applications (FLTA24), VALENCIA, SPAIN' 
-
-1. To reproduce the results in the paper run `driver.py` with the parameters in `exp_configs.csv`
-
-2. Each experiment will output a `.csv` file with the resuting metrics
-
-3. Histogram plots and a summary table of various experiments can be obtained running `src/utils_results.py`
-  
-
-  
+#### Code for the paper: *Comparative Evaluation of Clustered Federated Learning Methods*
+
+##### Submited to 'The 2nd IEEE International Conference on Federated Learning Technologies and Applications (FLTA24), VALENCIA, SPAIN' 
+
+1. To reproduce the results in the paper run `driver.py` with the parameters in `exp_configs.csv`
+
+2. Each experiment will output a `.csv` file with the resuting metrics
+
+3. Histogram plots and a summary table of various experiments can be obtained running `src/utils_results.py`
+  
+
+  
diff --git a/driver.py b/driver.py
index 9332a4cb0671b019de3fd7c9a1857785b6e6d42a..5229830d5d5b92d61ba3bced2eb1eb51b48081f8 100644
--- a/driver.py
+++ b/driver.py
@@ -1,11 +1,15 @@
+import os
+
+# Set the environment variable for deterministic behavior with CuBLAS (Give reproductibility with CUDA) 
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 
 import click
 
 @click.command()
 @click.option('--exp_type', help="The experiment type to run")
-@click.option('--heterogeneity_type', help="The data heterogeneity to test (or dataset)")
 @click.option('--dataset')
 @click.option('--nn_model', help= "The training model to use ('linear (default) or 'convolutional')")
+@click.option('--heterogeneity_type', help="The data heterogeneity to test (or dataset)")
 @click.option('--num_clients', type=int)
 @click.option('--num_samples_by_label', type=int)
 @click.option('--num_clusters', type=int)
@@ -14,7 +18,6 @@ import click
 @click.option('--seed', type=int)
 
 
-
 def main_driver(exp_type, dataset, nn_model, heterogeneity_type, num_clients, num_samples_by_label, num_clusters, centralized_epochs, federated_rounds, seed):
 
     from pathlib import Path
diff --git a/exp_configs.csv b/exp_configs.csv
index 59a39fd33b4f22f9a40b9345e82c1778e67fbe71..677e5b1b7ec7685ac5ebc3e1db65c61b943b7d3e 100644
--- a/exp_configs.csv
+++ b/exp_configs.csv
@@ -1,2 +1,10 @@
-exp_type,dataset,nn_model,heterogeneity_type,num_clients,num_samples_by_label,num_clusters,centralized_epochs,federated_rounds,seed
-server,cifar10,convolutional,concept-shift-on-labels,8,600,6,3,2,42
+exp_type,dataset,nn_model,heterogeneity_type,num_clients,num_samples_by_label,num_clusters,centralized_epochs,federated_rounds,seed
+pers-centralized,cifar10,convolutional,concept-shift-on-features,48,100,4,200,0,42
+global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,20,50,42
+global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,20,100,42
+global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,20,150,42
+global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,20,200,42
+global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,50,50,42
+global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,50,100,42
+global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,50,150,42
+global-federated,cifar10,convolutional,concept-shift-on-features,48,100,4,50,200,42
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 3d5b25f724115d5e59a0ec95d067b1e28836bc8b..49dd21325bab0c39a947e406f5c214391a968472 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,13 +1,13 @@
-# Automatically generated by https://github.com/damnever/pigar.
-
-imbalanced-learn==0.12.3
-inputimeout==1.0.4
-kiwisolver==1.4.5
-matplotlib==3.9.0
-numpy==1.26.4
-opencv-python==4.10.0.84
-pandas==2.2.2
-scikit-learn==1.5.0
-scipy==1.14.0
-tensorflow==2.16.2
-
+# Automatically generated by https://github.com/damnever/pigar.
+
+imbalanced-learn==0.12.3
+inputimeout==1.0.4
+kiwisolver==1.4.5
+matplotlib==3.9.0
+numpy==1.26.4
+opencv-python==4.10.0.84
+pandas==2.2.2
+scikit-learn==1.5.0
+scipy==1.14.0
+tensorflow==2.16.2
+
diff --git a/src/models.py b/src/models.py
index 78387938fac081a9759143cae170198c354f5c47..77b903cc6b17296c66a8e0c7fb156d17b8f733de 100644
--- a/src/models.py
+++ b/src/models.py
@@ -2,23 +2,26 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 def accuracy(outputs, labels):
     _, preds = torch.max(outputs, dim=1)
     return torch.tensor(torch.sum(preds == labels).item() / len(preds))
 
-
 class ImageClassificationBase(nn.Module):
-    def training_step(self, batch):
-        images, labels = batch 
+    def training_step(self, batch, device):
+        images, labels = batch
+        images, labels = images.to(device), labels.to(device) 
         out = self(images)
-        loss = F.cross_entropy(out, labels) # Calculate loss
+        loss = F.cross_entropy(out, labels)  
         return loss
     
-    def validation_step(self, batch):
-        images, labels = batch 
+    def validation_step(self, batch, device):
+        images, labels = batch
+        images, labels = images.to(device), labels.to(device)  
         out = self(images)
-        loss = F.cross_entropy(out, labels)   # Calculate loss
+        loss = F.cross_entropy(out, labels) 
         acc = accuracy(out, labels)
         return {'val_loss': loss.detach(), 'val_acc': acc}
         
@@ -32,60 +35,106 @@ class ImageClassificationBase(nn.Module):
     def epoch_end(self, epoch, result):
         print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
             epoch, result['train_loss'], result['val_loss'], result['val_acc']))
-        
-
 
 class GenericLinearModel(ImageClassificationBase):
-    
     def __init__(self, in_size, n_channels):
-        
         super().__init__()
-        
         self.in_size = in_size
-
         self.network = nn.Sequential(
-            nn.Linear(in_size*in_size,200),
-            nn.Linear(200, 10))
+            nn.Linear(in_size * in_size, 200),
+            nn.Linear(200, 10)
+        )
         
     def forward(self, xb):
         xb = xb.view(-1, self.in_size * self.in_size)
         return self.network(xb)
-        
-    
-
 
 class GenericConvModel(ImageClassificationBase):
     def __init__(self, in_size, n_channels):
         super().__init__()
-        
         self.img_final_size = int(in_size / (2**3))
-
         self.network = nn.Sequential(
             nn.Conv2d(n_channels, 32, kernel_size=3, padding=1),
+            nn.BatchNorm2d(32),
             nn.ReLU(),
             nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(64),
             nn.ReLU(),
-            nn.MaxPool2d(2, 2), # output: 64 x 16 x 16
-
+            nn.MaxPool2d(2, 2),  # output: 64 x 16 x 16
+            nn.Dropout(0.25),
+            
             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(128),
             nn.ReLU(),
             nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(128),
             nn.ReLU(),
-            nn.MaxPool2d(2, 2), # output: 128 x 8 x 8
-
+            nn.MaxPool2d(2, 2),  # output: 128 x 8 x 8
+            nn.Dropout(0.25),
+            
+            
             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(256),
             nn.ReLU(),
             nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(256),
             nn.ReLU(),
-            nn.MaxPool2d(2, 2), # output: 256 x 4 x 4
-
-            nn.Flatten(), 
+            nn.MaxPool2d(2, 2),  # output: 256 x 4 x 4
+            nn.Dropout(0.25),
+            
+            nn.Flatten(),
             nn.Linear(256 * self.img_final_size * self.img_final_size, 1024),
             nn.ReLU(),
             nn.Linear(1024, 512),
             nn.ReLU(),
-            nn.Linear(512, 10))
+            nn.Linear(512, 10)
+        )
         
     def forward(self, xb):
-        return self.network(xb)
-    
\ No newline at end of file
+        return self.network(xb)    
+        
+
+class CovNet(ImageClassificationBase):
+    def __init__(self, in_size, n_channels):
+        super().__init__()
+        self.img_final_size = int(in_size / (2**3))
+        self.network = nn.Sequential(
+            nn.Conv2d(n_channels, 32, kernel_size=3, padding=1),
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),  # output: 32 x 16 x 16
+            nn.Dropout(0.25),
+            
+            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),  # output: 64 x 8 x 8
+            nn.Dropout(0.25),
+            
+            
+            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(128),
+            nn.ReLU(),
+            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(128),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),  # output: 128 x 4 x 4
+            nn.Dropout(0.25),
+            
+            nn.Flatten(),
+            nn.Linear(128 * self.img_final_size * self.img_final_size, 512),
+            nn.ReLU(),
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Linear(256, 10)
+        )
+        
+    def forward(self, xb):
+        return self.network(xb)           
+        
diff --git a/src/utils_data.py b/src/utils_data.py
index 7cb231c5e34c28489e97794c507276bae012a891..d15f37bd86ba76aa26c9730079e2a32729f076d2 100644
--- a/src/utils_data.py
+++ b/src/utils_data.py
@@ -3,35 +3,6 @@ from torch.utils.data import DataLoader
 from numpy import ndarray
 from typing import Tuple
 
-
-class AddGaussianNoise(object):
-    
-    def __init__(self, mean=0., std=1.):
-        self.std = std
-        self.mean = mean
-        
-    def __call__(self, tensor):
-        import torch
-        return tensor + torch.randn(tensor.size()) * self.std + self.mean
-    
-    def __repr__(self):
-        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
-
-class AddRandomJitter(object):
-
-    def __init__(self, brightness =0.5, contrast = 1, saturation = 0.1, hue = 0.5):
-        self.brightness = brightness,
-        self.contrast = contrast, 
-        self.saturation = saturation, 
-        self.hue = hue
-    
-    def __call__(self, tensor):
-        import torchvision.transforms as transforms
-        transform = transforms.ColorJitter(brightness = self.brightness, contrast= self.contrast, 
-                               saturation = self.saturation, hue = self.hue)
-        return transform(tensor)
-    
-
 def shuffle_list(list_samples : int, seed : int) -> list: 
     
     """Function to shuffle the samples list
@@ -75,42 +46,28 @@ def create_label_dict(dataset : str, nn_model : str) -> dict:
     import sys
     import numpy as np
     import torchvision
+   
     import torchvision.transforms as transforms
 
-    transform = transforms.Compose(
-    [transforms.ToTensor(),
-     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-     AddGaussianNoise(0., 1.),
-     AddRandomJitter()])
-
-    
-
     if dataset == "fashion-mnist":
-        fashion_mnist = torchvision.datasets.MNIST("datasets", download=True, transform=transform)
+        fashion_mnist = torchvision.datasets.MNIST("datasets", download=True)
         (x_data, y_data) = fashion_mnist.data, fashion_mnist.targets
     
-        if nn_model == "convolutional":
+        if nn_model in ["convolutional","CovNet"]:
             x_data = x_data.unsqueeze(1)
 
     elif dataset == 'mnist':
         mnist = torchvision.datasets.MNIST("datasets", download=True)
         (x_data, y_data) = mnist.data, mnist.targets
-        
-        if nn_model == "convolutional":
-            x_data = x_data.unsqueeze(1)
-
+    
     elif dataset == "cifar10":
-        cifar10 = torchvision.datasets.CIFAR10("datasets", download=True, transform=transform)
+        cifar10 = torchvision.datasets.CIFAR10("datasets", download=True)
         (x_data, y_data) = cifar10.data, cifar10.targets
-        x_data = np.transpose(x_data, (0, 3, 1, 2))
+        #x_data = np.transpose(x_data, (0, 3, 1, 2))
         
     elif dataset == 'kmnist':
-        kmnist = torchvision.datasets.KMNIST("datasets", download=True, transform=transform)
-        (x_data, y_data)  = kmnist.load_data()
-
-        if nn_model == "convolutional":
-            x_data = x_data.unsqueeze(1)
-    
+        kmnist = torchvision.datasets.KMNIST("datasets", download=True)
+        (x_data, y_data)  = kmnist.load_data()    
     else:
         sys.exit("Unrecognized dataset. Please make sure you are using one of the following ['mnist', fashion-mnist', 'kmnist']")    
 
@@ -194,10 +151,67 @@ def rotate_images(client: Client, rotation: int) -> None:
 
     return
 
+import torch 
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader, TensorDataset, Dataset
+import torchvision.transforms as transforms
+
+
+class AddGaussianNoise(object):
+    
+    def __init__(self, mean=0., std=1.):
+        self.std = std
+        self.mean = mean
+        
+    def __call__(self, tensor):
+        import torch
+        return tensor + torch.randn(tensor.size()) * self.std + self.mean
+    
+    def __repr__(self):
+        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
+
+class AddRandomJitter(object):
 
+    def __init__(self, brightness =0.5, contrast = 1, saturation = 0.1, hue = 0.5):
+        self.brightness = brightness,
+        self.contrast = contrast, 
+        self.saturation = saturation, 
+        self.hue = hue
+    
+    def __call__(self, tensor):
+        import torchvision.transforms as transforms
+        transform = transforms.ColorJitter(brightness = self.brightness, contrast= self.contrast, 
+                               saturation = self.saturation, hue = self.hue)
+        return transform(tensor)
 
-def data_preparation(client : Client, row_exp : dict) -> None:
+class CustomDataset(Dataset):
     
+    def __init__(self, data, labels, transform=None):
+        # Ensure data is in (N, H, W, C) format
+        self.data = data  # Assume data is in (N, H, W, C) format
+        self.labels = labels
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        image = self.data[idx]  # Shape (H, W, C)
+        label = self.labels[idx]
+
+        # Convert image to tensor and permute to (C, H, W)
+        image = torch.tensor(image, dtype=torch.float)  # Convert to tensor
+        image = image.permute(2, 0, 1)  # Change to (C, H, W)
+
+        # Apply transformation if provided
+        if self.transform:
+            image = self.transform(image)
+
+        return image, label
+
+
+
+def data_preparation(client: Client, row_exp: dict) -> None:
     """Saves Dataloaders of train and test data in the Client attributes 
     
     Arguments:
@@ -206,37 +220,52 @@ def data_preparation(client : Client, row_exp : dict) -> None:
     """
 
     def to_device_tensor(data, device, data_dtype):
-    
         data = torch.tensor(data, dtype=data_dtype)
-        data.to(device)
+        data = data.to(device)
         return data
     
     import torch 
     from sklearn.model_selection import train_test_split
-    from torch.utils.data import DataLoader, TensorDataset
+    from torch.utils.data import DataLoader, TensorDataset, Dataset
+    import torchvision.transforms as transforms
+    import numpy as np  # Import NumPy for transpose operation
+    
+    # Define data augmentation transforms
+    train_transform = transforms.Compose([
+        transforms.RandomHorizontalFlip(),
+        AddGaussianNoise(0., 1.),
+        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize if needed
+    ])
+    
+    # Transform for validation and test data (no augmentation, just normalization)
+    test_val_transform = transforms.Compose([
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize if needed
+    ])
 
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     
-    x_data, x_test, y_data, y_test = train_test_split(client.data['x'], client.data['y'], test_size=0.3, random_state=row_exp['seed'],stratify=client.data['y'])
-    x_train, x_val, y_train, y_val  = train_test_split(x_data, y_data, test_size=0.25, random_state=42) 
+    # Split into train, validation, and test sets
+    x_data, x_test, y_data, y_test = train_test_split(client.data['x'], client.data['y'], test_size=0.3, random_state=row_exp['seed'], stratify=client.data['y'])
+    x_train, x_val, y_train, y_val  = train_test_split(x_data, y_data, test_size=0.25, random_state=42)
 
-    x_train_tensor = to_device_tensor(x_train, device, torch.float32)
-    y_train_tensor = to_device_tensor(y_train, device, torch.long)
 
-    x_val_tensor = to_device_tensor(x_val, device, torch.float32)
-    y_val_tensor = to_device_tensor(y_val, device, torch.long)
+    # Create datasets with transformations
+    train_dataset = CustomDataset(x_train, y_train, transform=train_transform)
+    val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform)
+    test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform)
 
-    x_test_tensor = to_device_tensor(x_test, device, torch.float32)
-    y_test_tensor = to_device_tensor(y_test, device, torch.long)
+    # Create DataLoaders
+    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
+    validation_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
+    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
 
-    train_loader = DataLoader(TensorDataset(x_train_tensor, y_train_tensor), batch_size=128, shuffle=True)
-    validation_loader = DataLoader(TensorDataset(x_val_tensor, y_val_tensor), batch_size=128, shuffle=True)
-    test_loader = DataLoader( TensorDataset(x_test_tensor, y_test_tensor), batch_size=128, shuffle = True)    
+    # Store DataLoaders in the client object
+    setattr(client, 'data_loader', {'train': train_loader, 'val': validation_loader, 'test': test_loader})
+    setattr(client, 'train_test', {'x_train': x_train, 'x_val': x_val, 'x_test': x_test, 'y_train': y_train, 'y_val': y_val, 'y_test': y_test})
+
+    return
 
-    setattr(client, 'data_loader', {'train' : train_loader, 'val' : validation_loader, 'test': test_loader, })
-    setattr(client,'train_test', {'x_train': x_train, 'x_val' : x_val, 'x_test': x_test, 'y_train': y_train,  'y_val': y_val, 'y_test': y_test})
-    
-    return 
 
 
 def get_dataset_heterogeneities(heterogeneity_type: str) -> dict:
@@ -278,7 +307,7 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
 
     """
 
-    from src.models import GenericConvModel
+    from src.models import GenericConvModel, CovNet
     from src.utils_fed import init_server_cluster
     import torch
     
@@ -297,6 +326,9 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
     elif row_exp['nn_model'] == "convolutional": 
         
         model_server = Server(GenericConvModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
+    elif row_exp['nn_model'] == "CovNet":
+        model_server = Server(CovNet(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
+       
 
     model_server.model.to(device)
 
@@ -594,7 +626,8 @@ def centralize_data(list_clients : list) -> Tuple[DataLoader, DataLoader]:
     import torch 
     from torch.utils.data import DataLoader,TensorDataset
     import numpy as np 
-
+    
+    
     x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))],axis = 0)
     y_train = np.concatenate([list_clients[id].train_test['y_train'] for id in range(len(list_clients))],axis = 0)
     x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
@@ -616,6 +649,58 @@ def centralize_data(list_clients : list) -> Tuple[DataLoader, DataLoader]:
     
     return train_loader, val_loader, test_loader
 
+def centralize_data(list_clients: list) -> Tuple[DataLoader, DataLoader]:
+    """Centralize data of the federated learning setup for central model comparison
+
+    Arguments:
+        list_clients : The list of Client Objects
+
+    Returns:
+        Train and test torch DataLoaders with data of all Clients
+    """
+    
+    from torchvision import transforms
+    import torch 
+    from torch.utils.data import DataLoader,TensorDataset
+    import numpy as np 
+
+    train_transform = transforms.Compose([
+        transforms.RandomHorizontalFlip(),
+        AddGaussianNoise(0., 1.),
+        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize if needed
+    ])
+    
+    # Transform for validation and test data (no augmentation, just normalization)
+    test_val_transform = transforms.Compose([
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize if needed
+    ])
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    
+    # Concatenate training data from all clients
+    x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))], axis=0)
+    y_train = np.concatenate([list_clients[id].train_test['y_train'] for id in range(len(list_clients))], axis=0)
+
+    # Concatenate validation data from all clients
+    x_val = np.concatenate([list_clients[id].train_test['x_val'] for id in range(len(list_clients))], axis=0)
+    y_val = np.concatenate([list_clients[id].train_test['y_val'] for id in range(len(list_clients))], axis=0)
+
+    # Concatenate test data from all clients
+    x_test = np.concatenate([list_clients[id].train_test['x_test'] for id in range(len(list_clients))], axis=0)
+    y_test = np.concatenate([list_clients[id].train_test['y_test'] for id in range(len(list_clients))], axis=0)
+
+    # Create Custom Datasets
+    train_dataset = CustomDataset(x_train, y_train, transform=train_transform)
+    val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform)
+    test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform)
+
+    # Create DataLoaders
+    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
+    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)  # Validation typically not shuffled
+    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)  # Test data typically not shuffled
+
+    return train_loader, val_loader, test_loader
 
 
 
diff --git a/src/utils_fed.py b/src/utils_fed.py
index 7574d3ef21cad8e3c0ba75567bb551c831400b39..345bf7c0df90b9f9dca50046c07ba3f1c9555d9e 100644
--- a/src/utils_fed.py
+++ b/src/utils_fed.py
@@ -1,8 +1,11 @@
 from src.fedclass import Server
+import torch
 import torch.nn as nn
 import pandas as pd
 from torch.utils.data import DataLoader
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 def send_server_model_to_client(list_clients : list, my_server : Server) -> None:
     
     """ Function to copy the Server model to client attributes in a FL protocol
@@ -120,12 +123,12 @@ def model_weight_matrix(list_clients : list) -> pd.DataFrame:
 
     model_dict = {client.id : client.model for client in list_clients}
 
-    shapes = [param.data.numpy().shape for param in next(iter(model_dict.values())).parameters()]
+    shapes = [param.data.cpu().numpy().shape for param in next(iter(model_dict.values())).parameters()]
     weight_matrix_np = np.empty((len(model_dict), sum(np.prod(shape) for shape in shapes)))
 
     for idx, (_, model) in enumerate(model_dict.items()):
 
-        model_weights = np.concatenate([param.data.numpy().flatten() for param in model.parameters()])
+        model_weights = np.concatenate([param.data.cpu().numpy().flatten() for param in model.parameters()])
         weight_matrix_np[idx, :] = model_weights
 
     weight_matrix = pd.DataFrame(weight_matrix_np, columns=[f'w_{i+1}' for i in range(weight_matrix_np.shape[1])])
@@ -196,7 +199,7 @@ def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict,
         p_expert_opintion : Parameter to avoid completly random assignment if neeed (default to 0)
     """
     
-    from src.models import GenericConvModel
+    from src.models import GenericLinearModel, GenericConvModel, CovNet
     import numpy as np
     import copy
 
@@ -211,10 +214,8 @@ def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict,
     p_rest = (1 - p_expert_opinion) / (row_exp['num_clusters'] - 1)
 
     my_server.num_clusters = row_exp['num_clusters']
-
     my_server.clusters_models = {cluster_id: GenericConvModel(in_size=imgs_params[0], n_channels=imgs_params[1]) for cluster_id in range(row_exp['num_clusters'])}
     
-    
     for client in list_clients:
     
         probs = [p_rest if x != list_heterogeneities.index(client.heterogeneity_class) % row_exp['num_clusters']
@@ -241,7 +242,8 @@ def loss_calculation(model : nn.modules, train_loader : DataLoader) -> float:
     import torch.nn as nn
 
     criterion = nn.CrossEntropyLoss()  
-
+    
+    model.to(device)
     model.eval()
 
     total_loss = 0.0
diff --git a/tests/refs/client_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv b/tests/refs/client_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv
index a16b9a7b1c01808b7b327a513b116cc174d91d40..709b3aefdd09492937e94b9c9fb7f686a2c5ecb5 100644
--- a/tests/refs/client_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv
+++ b/tests/refs/client_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv
@@ -1,7 +1,7 @@
-id,cluster_id,heterogeneity_class,accuracy
-0,0,none,82.0
-1,0,erosion,56.666666666666664
-2,0,dilatation,82.33333333333334
-3,0,none,83.66666666666667
-4,0,erosion,57.99999999999999
-5,0,dilatation,83.0
+id,cluster_id,heterogeneity_class,accuracy
+0,0,none,82.0
+1,0,erosion,56.666666666666664
+2,0,dilatation,82.33333333333334
+3,0,none,83.66666666666667
+4,0,erosion,57.99999999999999
+5,0,dilatation,83.0
diff --git a/tests/refs/global-federated_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv b/tests/refs/global-federated_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv
index 15b985617ada75f8f965f7d41fe75e6be6d335e9..07d5a6495b5998e98726ee296857516205735cbf 100644
--- a/tests/refs/global-federated_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv
+++ b/tests/refs/global-federated_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv
@@ -1,7 +1,7 @@
-id,cluster_id,heterogeneity_class,accuracy
-0,,none,74.27777777777777
-1,,erosion,74.27777777777777
-2,,dilatation,74.27777777777777
-3,,none,74.27777777777777
-4,,erosion,74.27777777777777
-5,,dilatation,74.27777777777777
+id,cluster_id,heterogeneity_class,accuracy
+0,,none,74.27777777777777
+1,,erosion,74.27777777777777
+2,,dilatation,74.27777777777777
+3,,none,74.27777777777777
+4,,erosion,74.27777777777777
+5,,dilatation,74.27777777777777
diff --git a/tests/refs/pers-centralized_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv b/tests/refs/pers-centralized_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv
index 6c0dd9d7fcefe75841ede1c4e5c571c004d439c8..5bcc37ae272a05b2ed913b8036f40c7f9e4a4bee 100644
--- a/tests/refs/pers-centralized_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv
+++ b/tests/refs/pers-centralized_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv
@@ -1,7 +1,7 @@
-id,cluster_id,heterogeneity_class,accuracy
-0,,none,65.33333333333333
-1,,erosion,39.5
-2,,dilatation,79.0
-3,,none,65.33333333333333
-4,,erosion,39.5
-5,,dilatation,79.0
+id,cluster_id,heterogeneity_class,accuracy
+0,,none,65.33333333333333
+1,,erosion,39.5
+2,,dilatation,79.0
+3,,none,65.33333333333333
+4,,erosion,39.5
+5,,dilatation,79.0
diff --git a/tests/refs/server_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv b/tests/refs/server_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv
index 82214a8a5a8c576aa652565a94570f7ce4fd7dd5..f3337b1b436a1e00874995dd2ad53d4e4a685e98 100644
--- a/tests/refs/server_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv
+++ b/tests/refs/server_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv
@@ -1,7 +1,7 @@
-id,cluster_id,heterogeneity_class,accuracy
-0,2,none,85.66666666666667
-1,1,erosion,66.66666666666666
-2,0,dilatation,86.66666666666667
-3,2,none,88.66666666666667
-4,1,erosion,68.66666666666667
-5,0,dilatation,89.0
+id,cluster_id,heterogeneity_class,accuracy
+0,2,none,85.66666666666667
+1,1,erosion,66.66666666666666
+2,0,dilatation,86.66666666666667
+3,2,none,88.66666666666667
+4,1,erosion,68.66666666666667
+5,0,dilatation,89.0
diff --git a/tests/test_utils_training.py b/tests/test_utils_training.py
index 5082727982a113f7ece42ed6ac9d204c9698dfa5..d8e91006fc7b6697d35ca8d20a06c18ea553e6cb 100644
--- a/tests/test_utils_training.py
+++ b/tests/test_utils_training.py
@@ -1,123 +1,123 @@
-import os
-import pytest
-
-from pathlib import Path
-
-if os.getenv('_PYTEST_RAISE', "0") != "0":
-
-    @pytest.hookimpl(tryfirst=True)
-    def pytest_exception_interact(call):
-        raise call.excinfo.value
-
-    @pytest.hookimpl(tryfirst=True)
-    def pytest_internalerror(excinfo):
-        raise excinfo.value
-
-
-def utils_extract_params(file_path: Path):
-    """  Creates a dictionary row_exp with the parameters for the experiment given a well formated results file path
-    """
-
-    with open (file_path, "r") as fp:
-        
-        keys = ['exp_type', 'dataset', 'nn_model', 'heterogeneity_type' , 'num_clients',
-                'num_samples_by_label' , 'num_clusters', 'centralized_epochs',
-                'federated_rounds', 'seed']
-        
-
-        parameters = file_path.stem.split('_')
-
-        row_exp = dict(
-            zip(keys,
-                parameters[:4] + [int(x) for x in  parameters[4:]])
-            )
-    
-    return row_exp
-
-
-def test_run_cfl_benchmark_oracle():
-
-    from pathlib import Path
-    import numpy as np
-    import pandas as pd
-
-    from src.utils_data import setup_experiment    
-    from src.utils_training import run_benchmark
-
-    file_path = Path("tests/refs/pers-centralized_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
-
-    row_exp = utils_extract_params(file_path) 
-   
-    model_server, list_clients = setup_experiment(row_exp)
-
-    df_results = run_benchmark(model_server, list_clients, row_exp)
-
-    assert all(np.isclose(df_results['accuracy'], pd.read_csv(file_path)['accuracy'], rtol=0.01))
-
-
-def test_run_cfl_benchmark_fl():
-
-    from pathlib import Path
-    import numpy as np
-    import pandas as pd
-
-    from src.utils_data import setup_experiment    
-    from src.utils_training import run_benchmark
-
-    file_path = Path("tests/refs/global-federated_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
-
-    row_exp = utils_extract_params(file_path) 
-   
-    model_server, list_clients = setup_experiment(row_exp)
-
-    df_results = run_benchmark(model_server, list_clients, row_exp)
-
-    assert all(np.isclose(df_results['accuracy'], pd.read_csv(file_path)['accuracy'], rtol=0.01))
-
-
-def test_run_cfl_client_side():
-
-    from pathlib import Path
-    import numpy as np
-    import pandas as pd
-
-    from src.utils_data import setup_experiment    
-    from src.utils_training import run_cfl_client_side
-
-    file_path = Path("tests/refs/client_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
-
-    row_exp = utils_extract_params(file_path) 
-   
-    model_server, list_clients = setup_experiment(row_exp)
-
-    df_results =  run_cfl_client_side(model_server, list_clients, row_exp)
-
-    assert all(np.isclose(df_results['accuracy'], pd.read_csv(file_path)['accuracy'], rtol=0.01))
-
-
-def test_run_cfl_server_side():
-
-    from pathlib import Path
-    import numpy as np
-    import pandas as pd
-
-    from src.utils_data import setup_experiment    
-    from src.utils_training import run_cfl_server_side
-
-    file_path = Path("tests/refs/server_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
-
-    row_exp = utils_extract_params(file_path) 
-   
-    model_server, list_clients = setup_experiment(row_exp)
-
-    df_results =  run_cfl_server_side(model_server, list_clients, row_exp)
-
-    assert all(np.isclose(df_results['accuracy'], pd.read_csv(file_path)['accuracy'], rtol=0.01))
-
-
-if __name__ == "__main__":
-    test_run_cfl_client_side()
-    test_run_cfl_server_side()
-    test_run_cfl_benchmark_fl()
-    test_run_cfl_benchmark_oracle()
+import os
+import pytest
+
+from pathlib import Path
+
+if os.getenv('_PYTEST_RAISE', "0") != "0":
+
+    @pytest.hookimpl(tryfirst=True)
+    def pytest_exception_interact(call):
+        raise call.excinfo.value
+
+    @pytest.hookimpl(tryfirst=True)
+    def pytest_internalerror(excinfo):
+        raise excinfo.value
+
+
+def utils_extract_params(file_path: Path):
+    """  Creates a dictionary row_exp with the parameters for the experiment given a well formated results file path
+    """
+
+    with open (file_path, "r") as fp:
+        
+        keys = ['exp_type', 'dataset', 'nn_model', 'heterogeneity_type' , 'num_clients',
+                'num_samples_by_label' , 'num_clusters', 'centralized_epochs',
+                'federated_rounds', 'seed']
+        
+
+        parameters = file_path.stem.split('_')
+
+        row_exp = dict(
+            zip(keys,
+                parameters[:4] + [int(x) for x in  parameters[4:]])
+            )
+    
+    return row_exp
+
+
+def test_run_cfl_benchmark_oracle():
+
+    from pathlib import Path
+    import numpy as np
+    import pandas as pd
+
+    from src.utils_data import setup_experiment    
+    from src.utils_training import run_benchmark
+
+    file_path = Path("tests/refs/pers-centralized_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
+
+    row_exp = utils_extract_params(file_path) 
+   
+    model_server, list_clients = setup_experiment(row_exp)
+
+    df_results = run_benchmark(model_server, list_clients, row_exp)
+
+    assert all(np.isclose(df_results['accuracy'], pd.read_csv(file_path)['accuracy'], rtol=0.01))
+
+
+def test_run_cfl_benchmark_fl():
+
+    from pathlib import Path
+    import numpy as np
+    import pandas as pd
+
+    from src.utils_data import setup_experiment    
+    from src.utils_training import run_benchmark
+
+    file_path = Path("tests/refs/global-federated_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
+
+    row_exp = utils_extract_params(file_path) 
+   
+    model_server, list_clients = setup_experiment(row_exp)
+
+    df_results = run_benchmark(model_server, list_clients, row_exp)
+
+    assert all(np.isclose(df_results['accuracy'], pd.read_csv(file_path)['accuracy'], rtol=0.01))
+
+
+def test_run_cfl_client_side():
+
+    from pathlib import Path
+    import numpy as np
+    import pandas as pd
+
+    from src.utils_data import setup_experiment    
+    from src.utils_training import run_cfl_client_side
+
+    file_path = Path("tests/refs/client_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
+
+    row_exp = utils_extract_params(file_path) 
+   
+    model_server, list_clients = setup_experiment(row_exp)
+
+    df_results =  run_cfl_client_side(model_server, list_clients, row_exp)
+
+    assert all(np.isclose(df_results['accuracy'], pd.read_csv(file_path)['accuracy'], rtol=0.01))
+
+
+def test_run_cfl_server_side():
+
+    from pathlib import Path
+    import numpy as np
+    import pandas as pd
+
+    from src.utils_data import setup_experiment    
+    from src.utils_training import run_cfl_server_side
+
+    file_path = Path("tests/refs/server_fashion-mnist_linear_features-distribution-skew_8_100_3_5_5_42.csv")
+
+    row_exp = utils_extract_params(file_path) 
+   
+    model_server, list_clients = setup_experiment(row_exp)
+
+    df_results =  run_cfl_server_side(model_server, list_clients, row_exp)
+
+    assert all(np.isclose(df_results['accuracy'], pd.read_csv(file_path)['accuracy'], rtol=0.01))
+
+
+if __name__ == "__main__":
+    test_run_cfl_client_side()
+    test_run_cfl_server_side()
+    test_run_cfl_benchmark_fl()
+    test_run_cfl_benchmark_oracle()
     
\ No newline at end of file