From b8826a5a27c7380ef8f40f0573930b75103de019 Mon Sep 17 00:00:00 2001 From: leahcimali <michaelbenalipro@gmail.com> Date: Mon, 7 Oct 2024 16:10:45 +0200 Subject: [PATCH] Correct loss calculation for client side --- src/utils_fed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils_fed.py b/src/utils_fed.py index 345bf7c..375b2af 100644 --- a/src/utils_fed.py +++ b/src/utils_fed.py @@ -252,7 +252,7 @@ def loss_calculation(model : nn.modules, train_loader : DataLoader) -> float: with torch.no_grad(): for inputs, targets in train_loader: - + inputs, targets = inputs.to(device), targets.to(device).long() outputs = model(inputs) loss = criterion(outputs, targets) -- GitLab