diff --git a/src/utils_data.py b/src/utils_data.py
index 8d093bb1704b685f76a886605c22cf0cf5251cd6..c678b7daf83b790134600b825bf182bcbb383a7b 100644
--- a/src/utils_data.py
+++ b/src/utils_data.py
@@ -55,27 +55,26 @@ def create_label_dict(dataset : str, nn_model : str) -> dict:
         fashion_mnist = torchvision.datasets.FashionMNIST("datasets", download=True)
         (x_data, y_data) = fashion_mnist.data, fashion_mnist.targets
         if nn_model == "convolutional":
-            x_data = x_data.unsqueeze(1) # Change shape to (samples, 1, H, W)
+            x_data = x_data.unsqueeze(3) # Change shape to (samples, 1, H, W)
         
     elif dataset == 'mnist':
         mnist = torchvision.datasets.MNIST("datasets", download=True)
         (x_data, y_data) = mnist.data, mnist.targets
         if nn_model == "convolutional":
-            x_data = x_data.unsqueeze(1) # Change shape to (samples, 1, H, W)
+            x_data = x_data.unsqueeze(3) # Change shape to (samples, 1, H, W)
     
     elif dataset == 'kmnist':
         kmnist = torchvision.datasets.KMNIST("datasets", download=True)
         x_data, y_data = kmnist.data, kmnist.targets
         if nn_model == "convolutional":
-            x_data = x_data.unsqueeze(1) # Change shape to (samples, 1, H, W)
+            x_data = x_data.unsqueeze(3) # Change shape to (samples, 1, H, W)
             
     elif dataset == "cifar10":
         if nn_model == "linear":
             raise ValueError("CIFAR-10 cannot be used with a linear model. Please use a convolutional model.")
 
         cifar10 = torchvision.datasets.CIFAR10("datasets", download=True)
-        x_data, y_data = cifar10.data, cifar10.targets
-        x_data = np.transpose(x_data, (0, 3, 1, 2))  # Change shape to (samples, C, H, W)
+        x_data, y_data = cifar10.data, cifar10.targets # (samples, H, W, C)
     else:
         sys.exit("Unrecognized dataset. Please make sure you are using one of the following ['mnist', fashion-mnist', 'kmnist']")    
 
@@ -177,7 +176,8 @@ class CustomDataset(Dataset):
     def __getitem__(self, idx):
         sample = self.x_data[idx]
         label = self.y_data[idx]
-
+        #if sample.shape[0] == 3:  # This implies CIFAR-10's RGB data
+        #    sample = transforms.ToPILImage()(sample)
         if self.transform:
             sample = self.transform(sample)
 
@@ -199,6 +199,7 @@ def data_transformation(row_exp : dict)-> tuple:
     '''
     if row_exp['dataset'] == 'cifar10': 
         train_transform = transforms.Compose([
+            transforms.ToPILImage(),
             transforms.RandomHorizontalFlip(),
             transforms.RandomRotation(20),
             transforms.RandomCrop(32, padding=4),
@@ -206,10 +207,12 @@ def data_transformation(row_exp : dict)-> tuple:
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         ])
         val_transform = transforms.Compose([
+        transforms.ToPILImage(),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         ])
         test_transform = transforms.Compose([
+        transforms.ToPILImage(),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         ])