diff --git a/.gitignore b/.gitignore
index afa9b70f86aff397f5ddd3837824540602e8642e..557207b998a02ae4e19ee855a069b4f1d97c3a6e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,5 @@ backup_results/*
 *.sh
 tests/__pycache__/*
 datasets/*
-pub/*
\ No newline at end of file
+pub/*
+data/*
diff --git a/driver.py b/driver.py
index e28c4227edf8afa08454adab8dd8d5ac34b15c02..9332a4cb0671b019de3fd7c9a1857785b6e6d42a 100644
--- a/driver.py
+++ b/driver.py
@@ -28,7 +28,7 @@ def main_driver(exp_type, dataset, nn_model, heterogeneity_type, num_clients, nu
     
 
     output_name =  row_exp.to_string(header=False, index=False, name=False).replace(' ', "").replace('\n','_')
-
+    
     hash_outputname = get_uid(output_name)
 
     pathlist = Path("results").rglob('*.json')
@@ -47,7 +47,6 @@ def main_driver(exp_type, dataset, nn_model, heterogeneity_type, num_clients, nu
     except Exception as e:
 
         print(f"Could not run experiment with parameters {output_name}. Exception {e}")
-
         return 
     
     launch_experiment(model_server, list_clients, row_exp, output_name)
@@ -94,7 +93,5 @@ def launch_experiment(model_server, list_clients, row_exp, output_name, save_res
         return
 
 
-
-
 if __name__ == "__main__":
     main_driver()
diff --git a/src/models.py b/src/models.py
index 3acd1e15c7474661de12c4efb8b68abe8e336422..78387938fac081a9759143cae170198c354f5c47 100644
--- a/src/models.py
+++ b/src/models.py
@@ -3,86 +3,89 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-class SimpleLinear(nn.Module):
-    """ Fully connected neural network with a single hidden layer of default size 200 and ReLU activations"""
+def accuracy(outputs, labels):
+    _, preds = torch.max(outputs, dim=1)
+    return torch.tensor(torch.sum(preds == labels).item() / len(preds))
+
+
+class ImageClassificationBase(nn.Module):
+    def training_step(self, batch):
+        images, labels = batch 
+        out = self(images)
+        loss = F.cross_entropy(out, labels) # Calculate loss
+        return loss
+    
+    def validation_step(self, batch):
+        images, labels = batch 
+        out = self(images)
+        loss = F.cross_entropy(out, labels)   # Calculate loss
+        acc = accuracy(out, labels)
+        return {'val_loss': loss.detach(), 'val_acc': acc}
+        
+    def validation_epoch_end(self, outputs):
+        batch_losses = [x['val_loss'] for x in outputs]
+        epoch_loss = torch.stack(batch_losses).mean()
+        batch_accs = [x['val_acc'] for x in outputs]
+        epoch_acc = torch.stack(batch_accs).mean()
+        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
+    
+    def epoch_end(self, epoch, result):
+        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
+            epoch, result['train_loss'], result['val_loss'], result['val_acc']))
+        
+
+
+class GenericLinearModel(ImageClassificationBase):
     
     def __init__(self, in_size, n_channels):
         
-        """ Initialization function
-        Arguments:
-            h1: int
-                Desired size of the hidden layer 
-        """
         super().__init__()
-        self.fc1 = nn.Linear(in_size*in_size,200)
-        self.fc2 = nn.Linear(200, 10)
+        
         self.in_size = in_size
 
-    def forward(self, x: torch.Tensor):
+        self.network = nn.Sequential(
+            nn.Linear(in_size*in_size,200),
+            nn.Linear(200, 10))
         
-        """ Forward pass function through the network
-        
-        Arguments:
-            x : torch.Tensor
-                input image of size in_size x in_size
-
-        Returns: 
-            log_softmax probabilities of the output layer
-        """
+    def forward(self, xb):
+        xb = xb.view(-1, self.in_size * self.in_size)
+        return self.network(xb)
         
-        x = x.view(-1, self.in_size * self.in_size)
-        x = F.relu(self.fc1(x))
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
     
 
-class SimpleConv(nn.Module):
 
-    """ Convolutional neural network with 3 convolutional layers and one fully connected layer
-    """
-
-    def __init__(self,  in_size, n_channels):
-        """ Initialization function
-        """
-        super(SimpleConv, self).__init__()
-                
-        self.conv1 = nn.Conv2d(n_channels, 16, 3, padding=1)
-        self.conv2 = nn.Conv2d(16, 32, 3,  padding=1)
-        self.conv3 = nn.Conv2d(32, 16, 3,  padding=1)
-        
-        self.img_final_size = int(in_size / 8)
+class GenericConvModel(ImageClassificationBase):
+    def __init__(self, in_size, n_channels):
+        super().__init__()
         
-        self.fc1 = nn.Linear(16 * self.img_final_size * self.img_final_size, 10)
-
-        self.pool = nn.MaxPool2d(2, 2)
+        self.img_final_size = int(in_size / (2**3))
 
-        self.dropout = nn.Dropout(p=0.2)
+        self.network = nn.Sequential(
+            nn.Conv2d(n_channels, 32, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2), # output: 64 x 16 x 16
 
-    def flatten(self, x : torch.Tensor):
-    
-        """Function to flatten a layer
-        
-            Arguments: 
-                x : torch.Tensor
+            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2), # output: 128 x 8 x 8
 
-            Returns:
-                flattened Tensor
-        """
-    
-        return x.reshape(x.size()[0], -1)
-    
-    def forward(self, x : torch.Tensor):
-        """ Forward pass through the network which returns the softmax probabilities of the output layer
+            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2), # output: 256 x 4 x 4
 
-        Arguments:
-            x : torch.Tensor
-                input image to use for training
-        """
+            nn.Flatten(), 
+            nn.Linear(256 * self.img_final_size * self.img_final_size, 1024),
+            nn.ReLU(),
+            nn.Linear(1024, 512),
+            nn.ReLU(),
+            nn.Linear(512, 10))
         
-        x = self.dropout(self.pool(F.relu(self.conv1(x))))
-        x = self.dropout(self.pool(F.relu(self.conv2(x))))
-        x = self.dropout(self.pool(F.relu(self.conv3(x))))
-        x = self.flatten(x)
-        x = self.fc1(x)
-
-        return F.log_softmax(x, dim=1)
\ No newline at end of file
+    def forward(self, xb):
+        return self.network(xb)
+    
\ No newline at end of file
diff --git a/src/utils_data.py b/src/utils_data.py
index 8cee10f92ea6cafb0c6f778fa0e3058832c5da4d..eea61c59be29003b17ae65f354064dd394628076 100644
--- a/src/utils_data.py
+++ b/src/utils_data.py
@@ -47,31 +47,37 @@ def create_label_dict(dataset : str, nn_model : str) -> dict:
     import numpy as np
     import torchvision
     from extra_keras_datasets import kmnist
-    
+    import torchvision.transforms as transforms
+
+    transform = transforms.Compose(
+    [transforms.ToTensor(),
+     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
     if dataset == "fashion-mnist":
-        fashion_mnist = torchvision.datasets.MNIST("datasets", download=True)
-        (x_train, y_train) = fashion_mnist.data, fashion_mnist.targets
+        fashion_mnist = torchvision.datasets.FashionMNIST("datasets", download=True, transform=transform)
+        (x_data, y_data) = fashion_mnist.data, fashion_mnist.targets
     
         if nn_model == "convolutional":
-            x_train = x_train.unsqueeze(1)
+            x_data = x_data.unsqueeze(1)
 
     elif dataset == 'mnist':
         mnist = torchvision.datasets.MNIST("datasets", download=True)
-        (x_train, y_train) = mnist.data, mnist.targets
+        (x_data, y_data) = mnist.data, mnist.targets
         
         if nn_model == "convolutional":
-            x_train = x_train.unsqueeze(1)
+            x_data = x_data.unsqueeze(1)
 
     elif dataset == "cifar10":
-        cifar10 = torchvision.datasets.CIFAR10("datasets", download=True)
-        (x_train, y_train) = cifar10.data, cifar10.targets
-        x_train = np.transpose(x_train, (0, 3, 1, 2))
-
+        cifar10 = torchvision.datasets.CIFAR10("datasets", download=True, transform=transform)
+        (x_data, y_data) = cifar10.data, cifar10.targets
+        x_data = np.transpose(x_data, (0, 3, 1, 2))
+        
     elif dataset == 'kmnist':
-        (x_train, y_train), _ = kmnist.load_data()
+        kmnist = torchvision.datasets.KMNIST("datasets", download=True, transform=transform)
+        (x_data, y_data) = kmnist.data, kmnist.targets
 
         if nn_model == "convolutional":
-            x_train = x_train.unsqueeze(1)
+            x_data = x_data.unsqueeze(1)
     
     else:
         sys.exit("Unrecognized dataset. Please make sure you are using one of the following ['mnist', fashion-mnist', 'kmnist']")    
@@ -80,8 +86,8 @@ def create_label_dict(dataset : str, nn_model : str) -> dict:
 
     for label in range(10):
        
-        label_indices = np.where(np.array(y_train) == label)[0]   
-        label_samples_x = x_train[label_indices]
+        label_indices = np.where(np.array(y_data) == label)[0]   
+        label_samples_x = x_data[label_indices]
         label_dict[label] = label_samples_x
         
     return label_dict
@@ -138,7 +144,6 @@ def rotate_images(client: Client, rotation: int) -> None:
     """
     
     import numpy as np
-    from math import prod
 
     images = client.data['x']
 
@@ -149,11 +154,8 @@ def rotate_images(client: Client, rotation: int) -> None:
         for img in images:
     
             orig_shape = img.shape             
-            img_flatten = img.flatten()
-
-            rotated_img = np.rot90(img, k=rotation//90)  # Rotate image by specified angle
+            rotated_img = np.rot90(img, k=rotation//90)  # Rotate image by specified angle 
             rotated_img = rotated_img.reshape(*orig_shape)
-
             rotated_images.append(rotated_img)   
     
         client.data['x'] = np.array(rotated_images)
@@ -171,37 +173,36 @@ def data_preparation(client : Client, row_exp : dict) -> None:
         row_exp : The current experiment's global parameters
     """
 
+    def to_device_tensor(data, device, data_dtype):
+    
+        data = torch.tensor(data, dtype=data_dtype)
+        data = data.to(device)
+        return data
+    
     import torch 
     from sklearn.model_selection import train_test_split
     from torch.utils.data import DataLoader, TensorDataset
 
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     
-    x_train, x_test, y_train, y_test = train_test_split(client.data['x'], client.data['y'], test_size=0.3, random_state=row_exp['seed'],stratify=client.data['y'])
-
-    x_train, x_test = x_train/255.0 , x_test/255.0
-
-    x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
-    x_train_tensor.to(device)
-
+    x_data, x_test, y_data, y_test = train_test_split(client.data['x'], client.data['y'], test_size=0.3, random_state=row_exp['seed'],stratify=client.data['y'])
+    x_train, x_val, y_train, y_val  = train_test_split(x_data, y_data, test_size=0.25, random_state=row_exp['seed']) 
 
-    y_train_tensor = torch.tensor(y_train, dtype=torch.long)
-    y_train_tensor.to(device)
+    x_train_tensor = to_device_tensor(x_train, device, torch.float32)
+    y_train_tensor = to_device_tensor(y_train, device, torch.long)
 
-    x_test_tensor = torch.tensor(x_test, dtype=torch.float32)
-    x_test_tensor.to(device)
-    y_test_tensor = torch.tensor(y_test, dtype=torch.long)
-    y_test_tensor.to(device)
+    x_val_tensor = to_device_tensor(x_val, device, torch.float32)
+    y_val_tensor = to_device_tensor(y_val, device, torch.long)
 
+    x_test_tensor = to_device_tensor(x_test, device, torch.float32)
+    y_test_tensor = to_device_tensor(y_test, device, torch.long)
 
-    train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
-    train_loader = DataLoader(train_dataset, batch_size=32)
-    
-    test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
-    test_loader = DataLoader(test_dataset, batch_size=32)    
+    train_loader = DataLoader(TensorDataset(x_train_tensor, y_train_tensor), batch_size=128, shuffle=True)
+    validation_loader = DataLoader(TensorDataset(x_val_tensor, y_val_tensor), batch_size=128, shuffle=True)
+    test_loader = DataLoader( TensorDataset(x_test_tensor, y_test_tensor), batch_size=128, shuffle = True)    
 
-    setattr(client, 'data_loader', {'train' : train_loader,'test': test_loader})
-    setattr(client,'train_test', {'x_train': x_train,'x_test': x_test, 'y_train': y_train, 'y_test': y_test})
+    setattr(client, 'data_loader', {'train' : train_loader, 'val' : validation_loader, 'test': test_loader, })
+    setattr(client,'train_test', {'x_train': x_train, 'x_val' : x_val, 'x_test': x_test, 'y_train': y_train,  'y_val': y_val, 'y_test': y_test})
     
     return 
 
@@ -245,7 +246,7 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
 
     """
 
-    from src.models import SimpleLinear, SimpleConv
+    from src.models import GenericLinearModel, GenericConvModel
     from src.utils_fed import init_server_cluster
     import torch
     
@@ -257,11 +258,11 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
 
     if row_exp['nn_model'] == "linear":
         
-        model_server = Server(SimpleLinear(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1])) 
+        model_server = Server(GenericLinearModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1])) 
     
     elif row_exp['nn_model'] == "convolutional": 
         
-        model_server = Server(SimpleConv(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
+        model_server = Server(GenericConvModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
 
     dict_clients = get_clients_data(row_exp['num_clients'],
                                     row_exp['num_samples_by_label'],
@@ -559,24 +560,25 @@ def centralize_data(list_clients : list) -> Tuple[DataLoader, DataLoader]:
     import numpy as np 
 
     x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))],axis = 0)
-    x_test = np.concatenate([list_clients[id].train_test['x_test'] for id in range(len(list_clients))],axis = 0)
-    
     y_train = np.concatenate([list_clients[id].train_test['y_train'] for id in range(len(list_clients))],axis = 0)
-    y_test = np.concatenate([list_clients[id].train_test['y_test'] for id in range(len(list_clients))],axis = 0)
-    
     x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
     y_train_tensor = torch.tensor(y_train, dtype=torch.long)
-    
+
+    x_val = np.concatenate([list_clients[id].train_test['x_val'] for id in range(len(list_clients))],axis = 0)
+    y_val = np.concatenate([list_clients[id].train_test['y_val'] for id in range(len(list_clients))],axis = 0)
+    x_val_tensor = torch.tensor(x_val, dtype=torch.float32)
+    y_val_tensor = torch.tensor(y_val, dtype=torch.long)
+
+    x_test = np.concatenate([list_clients[id].train_test['x_test'] for id in range(len(list_clients))],axis = 0)
+    y_test = np.concatenate([list_clients[id].train_test['y_test'] for id in range(len(list_clients))],axis = 0)
     x_test_tensor = torch.tensor(x_test, dtype=torch.float32)
     y_test_tensor = torch.tensor(y_test, dtype=torch.long)
     
-    train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
-    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
-    
-    test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
-    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
+    train_loader = DataLoader(TensorDataset(x_train_tensor, y_train_tensor), batch_size=64, shuffle=True)
+    val_loader = DataLoader(TensorDataset(x_val_tensor, y_val_tensor), batch_size=64, shuffle=True)
+    test_loader = DataLoader(TensorDataset(x_test_tensor, y_test_tensor), batch_size=64, shuffle=True)
     
-    return train_loader, test_loader
+    return train_loader, val_loader, test_loader
 
 
 
diff --git a/src/utils_fed.py b/src/utils_fed.py
index 7544e235ae1785d13d1e6d6e318648dbea79ecf6..b1ee0091680c3381b6007a0d0faed0caeb4a771a 100644
--- a/src/utils_fed.py
+++ b/src/utils_fed.py
@@ -196,7 +196,7 @@ def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict,
         p_expert_opintion : Parameter to avoid completly random assignment if neeed (default to 0)
     """
     
-    from src.models import SimpleLinear
+    from src.models import GenericLinearModel, GenericConvModel
     import numpy as np
     import copy
 
@@ -212,7 +212,7 @@ def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict,
 
     my_server.num_clusters = row_exp['num_clusters']
 
-    my_server.clusters_models = {cluster_id: SimpleLinear(in_size=imgs_params[0], n_channels=imgs_params[1]) for cluster_id in range(row_exp['num_clusters'])}
+    my_server.clusters_models = {cluster_id: GenericConvModel(in_size=imgs_params[0], n_channels=imgs_params[1]) if row_exp['nn_model'] == "convolutional" else GenericLinearModel(in_size=imgs_params[0], n_channels=imgs_params[1]) for cluster_id in range(row_exp['num_clusters'])}
     
     
     for client in list_clients:
@@ -291,8 +291,6 @@ def set_client_cluster(my_server : Server, list_clients : list, row_exp : dict)
         
         index_of_min_loss = np.argmin(cluster_losses)
         
-        #print(f"client {client.id} with heterogeneity {client.heterogeneity_class} cluster losses:", cluster_losses)
-
         client.model = copy.deepcopy(my_server.clusters_models[index_of_min_loss])
     
         client.cluster_id = index_of_min_loss
diff --git a/src/utils_training.py b/src/utils_training.py
index 1abd81d1feae0a0b52f2b6111d0a5c186cac53c3..4ae455d81ed1c0161a618bd442b507d28565c6f3 100644
--- a/src/utils_training.py
+++ b/src/utils_training.py
@@ -1,14 +1,13 @@
 import torch
 import torch.nn as nn
-import torch.optim as optim
+
 from torch.utils.data import DataLoader
 
 import pandas as pd
 
-from src.models import SimpleLinear
+from src.models import ImageClassificationBase
 from src.fedclass import Server
 
-lr = 0.01
 
 
 def run_cfl_server_side(model_server : Server, list_clients : list, row_exp : dict) -> pd.DataFrame:
@@ -21,7 +20,12 @@ def run_cfl_server_side(model_server : Server, list_clients : list, row_exp : di
         main_model : Type of Server model needed    
         list_clients : A list of Client Objects used as nodes in the FL protocol  
         row_exp : The current experiment's global parameters
+
+    Returns:
+
+        df_results : dataframe with the experiment results
     """
+
     from src.utils_fed import k_means_clustering
     import copy
     import torch 
@@ -29,24 +33,20 @@ def run_cfl_server_side(model_server : Server, list_clients : list, row_exp : di
     torch.manual_seed(row_exp['seed'])
 
     model_server = train_federated(model_server, list_clients, row_exp, use_cluster_models = False)
- 
-    model_server.clusters_models= {cluster_id: copy.deepcopy(model_server.model) for cluster_id in range(row_exp['num_clusters'])}
-  
+    model_server.clusters_models= {cluster_id: copy.deepcopy(model_server.model) for cluster_id in range(row_exp['num_clusters'])}  
     setattr(model_server, 'num_clusters', row_exp['num_clusters'])
 
     k_means_clustering(list_clients, row_exp['num_clusters'], row_exp['seed'])
-    
+
     model_server = train_federated(model_server, list_clients, row_exp, use_cluster_models = True)
 
     for client in list_clients :
 
-        acc = test_model(model_server.clusters_models[client.cluster_id], client.data_loader['test'])
-        
+        acc = test_model(model_server.clusters_models[client.cluster_id], client.data_loader['test'])    
         setattr(client, 'accuracy', acc)
 
-
     df_results = pd.DataFrame.from_records([c.to_dict() for c in list_clients])
-    
+
     return df_results 
 
 
@@ -72,16 +72,15 @@ def run_cfl_client_side(model_server : Server, list_clients : list, row_exp : di
 
         for client in list_clients:
 
-            client.model, _ = train_central(client.model, client.data_loader['train'], row_exp)
+            client.model, _ = train_central(client.model, client.data_loader['train'], client.data_loader['val'], row_exp)
 
         fedavg(model_server, list_clients)
-        
+
         set_client_cluster(model_server, list_clients, row_exp)
 
     for client in list_clients :
 
         acc = test_model(model_server.clusters_models[client.cluster_id], client.data_loader['test'])
-        
         setattr(client, 'accuracy', acc)
 
     df_results = pd.DataFrame.from_records([c.to_dict() for c in list_clients])
@@ -120,10 +119,8 @@ def run_benchmark(main_model : nn.Module, list_clients : list, row_exp : dict) -
             for heterogeneity_class in list_heterogeneities:
                 
                 list_clients_filtered = [client for client in list_clients if client.heterogeneity_class == heterogeneity_class]
-
-                train_loader, test_loader = centralize_data(list_clients_filtered)
-
-                model_trained, _ = train_central(curr_model, train_loader, row_exp) 
+                train_loader, val_loader, test_loader = centralize_data(list_clients_filtered)
+                model_trained, _ = train_central(curr_model, train_loader, val_loader, row_exp) 
 
                 global_acc = test_model(model_trained, test_loader) 
                      
@@ -134,11 +131,9 @@ def run_benchmark(main_model : nn.Module, list_clients : list, row_exp : dict) -
         case 'global-federated':
                 
             model_server = copy.deepcopy(curr_model)
-
             model_trained = train_federated(model_server, list_clients, row_exp, use_cluster_models = False)
-        
-            _, test_loader = centralize_data(list_clients)
 
+            _, test_loader = centralize_data(list_clients)
             global_acc = test_model(model_trained.model, test_loader) 
                      
             for client in list_clients : 
@@ -161,9 +156,8 @@ def train_federated(main_model, list_clients, row_exp, use_cluster_models = Fals
         row_exp: The current experiment's global parameters
         use_cluster_models: Boolean to determine whether to use personalization by clustering
     """
-
-    from src.utils_fed import send_server_model_to_client, send_cluster_models_to_clients, fedavg
     
+    from src.utils_fed import send_server_model_to_client, send_cluster_models_to_clients, fedavg
     
     for i in range(0, row_exp['federated_rounds']):
 
@@ -179,8 +173,7 @@ def train_federated(main_model, list_clients, row_exp, use_cluster_models = Fals
 
         for client in list_clients:
 
-            client.model, curr_acc = train_central(client.model, client.data_loader['train'], row_exp)
-
+            client.model, curr_acc = train_central(client.model, client.data_loader['train'], client.data_loader['val'], row_exp)
             accs.append(curr_acc)
 
         fedavg(main_model, list_clients)
@@ -188,54 +181,60 @@ def train_federated(main_model, list_clients, row_exp, use_cluster_models = Fals
     return main_model
 
 
-def train_central(main_model, train_loader, row_exp):
 
-    """ Main training function for centralized learning
+@torch.no_grad()
+def evaluate(model : nn.Module, val_loader : DataLoader) -> dict:
     
-    Arguments:
+    """ Returns a dict with loss and accuracy information"""
 
-        main_model : Server model used in our experiment
-        train_loader : DataLoader with the dataset to use for training
-        row_exp : The current experiment's global parameters
+    model.eval()
+    outputs = [model.validation_step(batch) for batch in val_loader]
+    return model.validation_epoch_end(outputs)
 
-    """
-    criterion = nn.CrossEntropyLoss()
-    
-    optimizer=optim.SGD if row_exp['nn_model'] == "linear" else optim.Adam
-    optimizer = optimizer(main_model.parameters(), lr=0.01) 
-   
-    main_model.train()
-    
-    for epoch in range(row_exp['centralized_epochs']):
-          
-        running_loss = total = correct = 0
 
-        for inputs, labels in train_loader:
 
-            optimizer.zero_grad()  
+def train_central(model : ImageClassificationBase, train_loader : DataLoader, val_loader : DataLoader, row_exp : dict):
 
-            outputs = main_model(inputs)  
+    """ Main training function for centralized learning
+    
+    Arguments:
+        model : Server model used in our experiment
+        train_loader : DataLoader with the training dataset
+        val_loader : Dataloader with the validation dataset
+        row_exp : The current experiment's global parameters
 
-            _, predicted = torch.max(outputs, 1)
+    Returns:
+        (model, history) : base model with trained weights / results at each training step
+    """
 
-            loss = criterion(outputs, labels)
+    opt_func=torch.optim.SGD #if row_exp['nn_model'] == "linear" else torch.optim.Adam
+    lr = 0.001
+    history = []
+    optimizer = opt_func(model.parameters(), lr)
+    
+    for epoch in range(row_exp['centralized_epochs']):
+        
+        model.train()
+        train_losses = []
+        
+        for batch in train_loader:
 
-            loss.backward() 
+            loss = model.training_step(batch)
+            train_losses.append(loss)
+            loss.backward()
 
             optimizer.step()
-
-            running_loss += loss.item() * inputs.size(0)
-            
-            total += labels.size(0)
-            
-            correct += (predicted == labels).sum().item()
-
-    accuracy = correct / total
-
-    main_model.eval() 
-
-    return main_model, accuracy
-
+            optimizer.zero_grad()
+        
+        result = evaluate(model, val_loader)
+        result['train_loss'] = torch.stack(train_losses).mean().item()        
+        
+        model.epoch_end(epoch, result)
+        
+        history.append(result)
+    
+    return model, history
+    
 
 def test_model(model : nn.Module, test_loader : DataLoader) -> float:
 
@@ -254,7 +253,6 @@ def test_model(model : nn.Module, test_loader : DataLoader) -> float:
     total = 0
     test_loss = 0.0
 
-
     with torch.no_grad():
 
         for inputs, labels in test_loader:
@@ -262,13 +260,11 @@ def test_model(model : nn.Module, test_loader : DataLoader) -> float:
             outputs = model(inputs)
 
             loss = criterion(outputs, labels)
-
             test_loss += loss.item() * inputs.size(0)
 
             _, predicted = torch.max(outputs, 1)
 
             total += labels.size(0)
-           
             correct += (predicted == labels).sum().item()
 
     test_loss = test_loss / len(test_loader.dataset)