diff --git a/src/utils_data.py b/src/utils_data.py
index 1a0058d33746eb509b0bd323878436af448acb3b..db1d33e18bc3ee207b5f7c1b1957a5b3306d2a0c 100644
--- a/src/utils_data.py
+++ b/src/utils_data.py
@@ -52,22 +52,21 @@ def create_label_dict(dataset : str, nn_model : str) -> dict:
     if dataset == "fashion-mnist":
         fashion_mnist = torchvision.datasets.FashionMNIST("datasets", download=True)
         (x_data, y_data) = fashion_mnist.data, fashion_mnist.targets
-    
-        if nn_model in ["convolutional"]:
-            x_data = x_data.unsqueeze(1)
-
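+        # add a trailing channel dim: (N, H, W) -> (N, H, W, 1), matching the (N, H, W, C) layout used downstream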
+        x_data = x_data.unsqueeze(3)
     elif dataset == 'mnist':
         mnist = torchvision.datasets.MNIST("datasets", download=True)
         (x_data, y_data) = mnist.data, mnist.targets
-    
+        x_data = x_data.unsqueeze(3)
     elif dataset == "cifar10":
         cifar10 = torchvision.datasets.CIFAR10("datasets", download=True)
         (x_data, y_data) = cifar10.data, cifar10.targets
-        
-        
+
     elif dataset == 'kmnist':
         kmnist = torchvision.datasets.KMNIST("datasets", download=True)
-        (x_data, y_data)  = kmnist.load_data()    
+        (x_data, y_data) = kmnist.data, kmnist.targets
+        x_data = x_data.unsqueeze(3)
+
     else:
         sys.exit("Unrecognized dataset. Please make sure you are using one of the following: ['mnist', 'fashion-mnist', 'cifar10', 'kmnist']")
 
@@ -184,7 +183,7 @@ class AddRandomJitter(object):
                                saturation = self.saturation, hue = self.hue)
         return transform(tensor)
 
-class CustomDataset(Dataset):
+class CifarDataset(Dataset):
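+    """Dataset wrapper expecting images in (N, H, W, C) format, with an optional transform."""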
     
     def __init__(self, data, labels, transform=None):
         # Ensure data is in (N, H, W, C) format
@@ -229,18 +228,26 @@ def data_preparation(client: Client, row_exp: dict) -> None:
     import numpy as np  # Import NumPy for transpose operation
     
     # Define data augmentation transforms
-
-    train_transform = transforms.Compose([
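+    # CIFAR-10 images are RGB and get augmentation plus 3-channel normalization;
+    # the grayscale datasets only need single-channel normalization.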
+    if row_exp['dataset'] == 'cifar10':
+        train_transform = transforms.Compose([
         transforms.RandomHorizontalFlip(),
         transforms.RandomRotation(20),
         transforms.RandomCrop(32, padding=4),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-    ])
-    
-    # Transform for validation and test data (no augmentation, just normalization)
-    test_val_transform = transforms.Compose([
+        ])
+        # Transform for validation and test data (no augmentation, just normalization)
+        test_val_transform = transforms.Compose([
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize if needed
     ])
+    else:
+        train_transform = transforms.Compose([
+            transforms.Normalize((0.5,), (0.5,)),  # single-channel normalization
+        ])
+
+        # Transform for validation and test data (no augmentation, just normalization)
+        test_val_transform = transforms.Compose([
+            transforms.Normalize((0.5,), (0.5,)),  # single-channel normalization
+        ])
 
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     
@@ -250,9 +257,9 @@ def data_preparation(client: Client, row_exp: dict) -> None:
 
 
     # Create datasets with transformations
-    train_dataset = CustomDataset(x_train, y_train, transform=train_transform)
-    val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform)
-    test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform)
+    train_dataset = CifarDataset(x_train, y_train, transform=train_transform)
+    val_dataset = CifarDataset(x_val, y_val, transform=test_val_transform)
+    test_dataset = CifarDataset(x_test, y_test, transform=test_val_transform)
 
     # Create DataLoaders
     train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
@@ -306,7 +313,7 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
 
     """
 
-    from src.models import GenericConvModel
+    from src.models import GenericConvModel, GenericLinearModel
     from src.utils_fed import init_server_cluster
     import torch
     
@@ -320,7 +327,7 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
 
     if row_exp['nn_model'] == "linear":
         
-        model_server = Server(GenericConvModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1])) 
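+        # use the fully connected model for the 'linear' option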
+        model_server = Server(GenericLinearModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1])) 
     
     elif row_exp['nn_model'] == "convolutional": 
         
@@ -651,9 +658,9 @@ def centralize_data(list_clients: list) -> Tuple[DataLoader, DataLoader]:
     y_test = np.concatenate([list_clients[id].train_test['y_test'] for id in range(len(list_clients))], axis=0)
 
     # Create Custom Datasets
-    train_dataset = CustomDataset(x_train, y_train, transform=train_transform)
-    val_dataset = CustomDataset(x_val, y_val, transform=test_val_transform)
-    test_dataset = CustomDataset(x_test, y_test, transform=test_val_transform)
+    train_dataset = CifarDataset(x_train, y_train, transform=train_transform)
+    val_dataset = CifarDataset(x_val, y_val, transform=test_val_transform)
+    test_dataset = CifarDataset(x_test, y_test, transform=test_val_transform)
 
     # Create DataLoaders
     train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)