diff --git a/src/utils_data.py b/src/utils_data.py index 8d093bb1704b685f76a886605c22cf0cf5251cd6..c678b7daf83b790134600b825bf182bcbb383a7b 100644 --- a/src/utils_data.py +++ b/src/utils_data.py @@ -55,27 +55,26 @@ def create_label_dict(dataset : str, nn_model : str) -> dict: fashion_mnist = torchvision.datasets.FashionMNIST("datasets", download=True) (x_data, y_data) = fashion_mnist.data, fashion_mnist.targets if nn_model == "convolutional": - x_data = x_data.unsqueeze(1) # Change shape to (samples, 1, H, W) + x_data = x_data.unsqueeze(3) # Change shape to (samples, 1, H, W) elif dataset == 'mnist': mnist = torchvision.datasets.MNIST("datasets", download=True) (x_data, y_data) = mnist.data, mnist.targets if nn_model == "convolutional": - x_data = x_data.unsqueeze(1) # Change shape to (samples, 1, H, W) + x_data = x_data.unsqueeze(3) # Change shape to (samples, 1, H, W) elif dataset == 'kmnist': kmnist = torchvision.datasets.KMNIST("datasets", download=True) x_data, y_data = kmnist.data, kmnist.targets if nn_model == "convolutional": - x_data = x_data.unsqueeze(1) # Change shape to (samples, 1, H, W) + x_data = x_data.unsqueeze(3) # Change shape to (samples, 1, H, W) elif dataset == "cifar10": if nn_model == "linear": raise ValueError("CIFAR-10 cannot be used with a linear model. Please use a convolutional model.") cifar10 = torchvision.datasets.CIFAR10("datasets", download=True) - x_data, y_data = cifar10.data, cifar10.targets - x_data = np.transpose(x_data, (0, 3, 1, 2)) # Change shape to (samples, C, H, W) + x_data, y_data = cifar10.data, cifar10.targets # (samples, H, W, C) else: sys.exit("Unrecognized dataset. Please make sure you are using one of the following ['mnist', fashion-mnist', 'kmnist']") @@ -177,7 +176,8 @@ class CustomDataset(Dataset): def __getitem__(self, idx): sample = self.x_data[idx] label = self.y_data[idx] - + #if sample.shape[0] == 3: # This implies CIFAR-10's RGB data + # sample = transforms.ToPILImage()(sample) if self.transform: sample = self.transform(sample) @@ -199,6 +199,7 @@ def data_transformation(row_exp : dict)-> tuple: ''' if row_exp['dataset'] == 'cifar10': train_transform = transforms.Compose([ + transforms.ToPILImage(), transforms.RandomHorizontalFlip(), transforms.RandomRotation(20), transforms.RandomCrop(32, padding=4), @@ -206,10 +207,12 @@ def data_transformation(row_exp : dict)-> tuple: transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) val_transform = transforms.Compose([ + transforms.ToPILImage(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) test_transform = transforms.Compose([ + transforms.ToPILImage(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ])