From b8826a5a27c7380ef8f40f0573930b75103de019 Mon Sep 17 00:00:00 2001
From: leahcimali <michaelbenalipro@gmail.com>
Date: Mon, 7 Oct 2024 16:10:45 +0200
Subject: [PATCH] Correct loss calculation for client side

---
 src/utils_fed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/utils_fed.py b/src/utils_fed.py
index 345bf7c..375b2af 100644
--- a/src/utils_fed.py
+++ b/src/utils_fed.py
@@ -252,7 +252,7 @@ def loss_calculation(model : nn.modules, train_loader : DataLoader) -> float:
     with torch.no_grad():
 
         for inputs, targets in train_loader:
-
+            inputs, targets = inputs.to(device), targets.to(device).long()
             outputs = model(inputs)
 
             loss = criterion(outputs, targets)
-- 
GitLab