diff --git a/src/utils_fed.py b/src/utils_fed.py index 345bf7c0df90b9f9dca50046c07ba3f1c9555d9e..375b2af2ff69be880a9c3fc31879f3819a866718 100644 --- a/src/utils_fed.py +++ b/src/utils_fed.py @@ -252,7 +252,7 @@ def loss_calculation(model : nn.modules, train_loader : DataLoader) -> float: with torch.no_grad(): for inputs, targets in train_loader: - + inputs, targets = inputs.to(device), targets.to(device).long() outputs = model(inputs) loss = criterion(outputs, targets)