Skip to content
Snippets Groups Projects
Commit b8826a5a authored by leahcimali's avatar leahcimali
Browse files

Correct loss calculation for client side

parent 73487f6e
No related branches found
No related tags found
No related merge requests found
......@@ -252,7 +252,7 @@ def loss_calculation(model : nn.modules, train_loader : DataLoader) -> float:
with torch.no_grad():
for inputs, targets in train_loader:
inputs, targets = inputs.to(device), targets.to(device).long()
outputs = model(inputs)
loss = criterion(outputs, targets)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment