diff --git a/src/models.py b/src/models.py
index 77b903cc6b17296c66a8e0c7fb156d17b8f733de..8e95fc5ec9f8a18627fc1703e591abfd605f16d0 100644
--- a/src/models.py
+++ b/src/models.py
@@ -12,14 +12,14 @@ def accuracy(outputs, labels):
 class ImageClassificationBase(nn.Module):
     def training_step(self, batch, device):
         images, labels = batch
-        images, labels = images.to(device), labels.to(device) 
+        images, labels = images.to(device), labels.to(device).long() 
         out = self(images)
         loss = F.cross_entropy(out, labels)  
         return loss
     
     def validation_step(self, batch, device):
         images, labels = batch
-        images, labels = images.to(device), labels.to(device)  
+        images, labels = images.to(device), labels.to(device).long()
         out = self(images)
         loss = F.cross_entropy(out, labels) 
         acc = accuracy(out, labels)