diff --git a/src/utils_training.py b/src/utils_training.py
index 44a7c1710d3f20da2174751ec557a931c86ba54c..d1cc6d03d6e94316f29a4538d7e2bbc6cfda779f 100644
--- a/src/utils_training.py
+++ b/src/utils_training.py
@@ -177,14 +177,21 @@ def train_federated(main_model, list_clients, row_exp, use_cluster_models = Fals
     return main_model
 
 
-
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # shared device for every helper below
 @torch.no_grad()
 def evaluate(model : nn.Module, val_loader : DataLoader) -> dict:
     
     """ Returns a dict with loss and accuracy information"""
-
+    model.to(device)  # ensure the model sits on the same device as the batches
     model.eval()
-    outputs = [model.validation_step(batch) for batch in val_loader]
+    outputs = []
+    for batch in val_loader:
+        # Move the entire batch to the correct device
+        batch = [item.to(device) for item in batch]
+
+        # Call the validation step and append to outputs
+        output = model.validation_step(batch, device)
+        outputs.append(output)
     return model.validation_epoch_end(outputs)
 
 def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_loader: DataLoader, row_exp: dict):
@@ -201,13 +208,12 @@ def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_
     """
 
     # Check if CUDA is available and set the device
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     
     # Move the model to the appropriate device
     model.to(device)
 
-    opt_func = torch.optim.SGD  # if row_exp['nn_model'] == "linear" else torch.optim.Adam
-    lr = 0.01
+    opt_func = torch.optim.Adam
+    lr = 0.001
     history = []
     optimizer = opt_func(model.parameters(), lr)
 
@@ -218,15 +224,16 @@ def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_
         
         for batch in train_loader:
             # Move batch to the same device as the model
-            batch = [item.to(device) for item in batch]  # Assuming batch is a tuple of tensors
-            
-            loss = model.training_step(batch)
+            inputs, labels = [item.to(device) for item in batch]
+
+            # Pass the unpacked inputs and labels to the model's training step
+            loss = model.training_step((inputs, labels), device)
             train_losses.append(loss)
             loss.backward()
 
             optimizer.step()
             optimizer.zero_grad()
-        
+
         result = evaluate(model, val_loader)  # Ensure evaluate handles CUDA as needed
         result['train_loss'] = torch.stack(train_losses).mean().item()        
         
@@ -262,12 +269,8 @@ def test_model(model: nn.Module, test_loader: DataLoader) -> float:
 
     with torch.no_grad():  # No need to track gradients in evaluation
 
-        for inputs, labels in test_loader:
-            
-            # Move inputs and labels to the device
-            inputs, labels = inputs.to(device), labels.to(device)
-            
-            # Forward pass
+        for batch in test_loader:
+            inputs, labels = [item.to(device) for item in batch]
             outputs = model(inputs)
 
             # Compute the loss
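
Review note: the amended call sites now pass `device` into `model.training_step(...)` and `model.validation_step(...)`, which assumes `ImageClassificationBase` gained a matching parameter in a companion change. Below is a minimal sketch of what those methods could look like under that assumption; the class name comes from the `train_central` signature above, but the bodies, the cross-entropy loss, and the `val_loss`/`val_acc` metric keys are illustrative, not the repository's actual implementation.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ImageClassificationBase(nn.Module):
        """Hypothetical base class matching the call sites above; subclasses define forward()."""

        def training_step(self, batch, device):
            # `device` is accepted for API symmetry; the caller has already moved the batch
            images, labels = batch
            out = self(images)                   # forward pass
            return F.cross_entropy(out, labels)  # loss used for backward()

        def validation_step(self, batch, device):
            # evaluate() wraps this call in @torch.no_grad(), so no detach is needed here
            images, labels = batch
            out = self(images)
            loss = F.cross_entropy(out, labels)
            acc = (out.argmax(dim=1) == labels).float().mean()
            return {"val_loss": loss, "val_acc": acc}

        def validation_epoch_end(self, outputs):
            # Aggregate per-batch metrics into the dict consumed by train_central()
            val_loss = torch.stack([x["val_loss"] for x in outputs]).mean().item()
            val_acc = torch.stack([x["val_acc"] for x in outputs]).mean().item()
            return {"val_loss": val_loss, "val_acc": val_acc}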