From 73487f6ef207a61a2a51569de099f2e81a5f2a73 Mon Sep 17 00:00:00 2001
From: leahcimali <michaelbenalipro@gmail.com>
Date: Mon, 7 Oct 2024 15:50:34 +0200
Subject: [PATCH] Update model. py label for int to long tensor

---
 src/models.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/models.py b/src/models.py
index 77b903c..8e95fc5 100644
--- a/src/models.py
+++ b/src/models.py
@@ -12,14 +12,14 @@ def accuracy(outputs, labels):
 class ImageClassificationBase(nn.Module):
     def training_step(self, batch, device):
         images, labels = batch
-        images, labels = images.to(device), labels.to(device) 
+        images, labels = images.to(device), labels.to(device).long() 
         out = self(images)
         loss = F.cross_entropy(out, labels)  
         return loss
     
     def validation_step(self, batch, device):
         images, labels = batch
-        images, labels = images.to(device), labels.to(device)  
+        images, labels = images.to(device), labels.to(device).long()
         out = self(images)
         loss = F.cross_entropy(out, labels) 
         acc = accuracy(out, labels)
-- 
GitLab