diff --git a/src/utils_fed.py b/src/utils_fed.py
index 345bf7c0df90b9f9dca50046c07ba3f1c9555d9e..375b2af2ff69be880a9c3fc31879f3819a866718 100644
--- a/src/utils_fed.py
+++ b/src/utils_fed.py
@@ -252,7 +252,7 @@ def loss_calculation(model : nn.modules, train_loader : DataLoader) -> float:
     with torch.no_grad():
 
         for inputs, targets in train_loader:
-
+            inputs, targets = inputs.to(device), targets.to(device).long()
             outputs = model(inputs)
 
             loss = criterion(outputs, targets)