diff --git a/src/models.py b/src/models.py index 77b903cc6b17296c66a8e0c7fb156d17b8f733de..8e95fc5ec9f8a18627fc1703e591abfd605f16d0 100644 --- a/src/models.py +++ b/src/models.py @@ -12,14 +12,14 @@ def accuracy(outputs, labels): class ImageClassificationBase(nn.Module): def training_step(self, batch, device): images, labels = batch - images, labels = images.to(device), labels.to(device) + images, labels = images.to(device), labels.to(device).long() out = self(images) loss = F.cross_entropy(out, labels) return loss def validation_step(self, batch, device): images, labels = batch - images, labels = images.to(device), labels.to(device) + images, labels = images.to(device), labels.to(device).long() out = self(images) loss = F.cross_entropy(out, labels) acc = accuracy(out, labels)