From 73487f6ef207a61a2a51569de099f2e81a5f2a73 Mon Sep 17 00:00:00 2001 From: leahcimali <michaelbenalipro@gmail.com> Date: Mon, 7 Oct 2024 15:50:34 +0200 Subject: [PATCH] Update model. py label for int to long tensor --- src/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/models.py b/src/models.py index 77b903c..8e95fc5 100644 --- a/src/models.py +++ b/src/models.py @@ -12,14 +12,14 @@ def accuracy(outputs, labels): class ImageClassificationBase(nn.Module): def training_step(self, batch, device): images, labels = batch - images, labels = images.to(device), labels.to(device) + images, labels = images.to(device), labels.to(device).long() out = self(images) loss = F.cross_entropy(out, labels) return loss def validation_step(self, batch, device): images, labels = batch - images, labels = images.to(device), labels.to(device) + images, labels = images.to(device), labels.to(device).long() out = self(images) loss = F.cross_entropy(out, labels) acc = accuracy(out, labels) -- GitLab