Skip to content
Snippets Groups Projects
Commit 616b3166 authored by Leahcimali's avatar Leahcimali
Browse files

Convert data to PIL image for cifar10

parent c34194aa
No related branches found
No related tags found
No related merge requests found
...@@ -55,27 +55,26 @@ def create_label_dict(dataset : str, nn_model : str) -> dict: ...@@ -55,27 +55,26 @@ def create_label_dict(dataset : str, nn_model : str) -> dict:
fashion_mnist = torchvision.datasets.FashionMNIST("datasets", download=True) fashion_mnist = torchvision.datasets.FashionMNIST("datasets", download=True)
(x_data, y_data) = fashion_mnist.data, fashion_mnist.targets (x_data, y_data) = fashion_mnist.data, fashion_mnist.targets
if nn_model == "convolutional": if nn_model == "convolutional":
x_data = x_data.unsqueeze(1) # Change shape to (samples, 1, H, W) x_data = x_data.unsqueeze(3) # Change shape to (samples, 1, H, W)
elif dataset == 'mnist': elif dataset == 'mnist':
mnist = torchvision.datasets.MNIST("datasets", download=True) mnist = torchvision.datasets.MNIST("datasets", download=True)
(x_data, y_data) = mnist.data, mnist.targets (x_data, y_data) = mnist.data, mnist.targets
if nn_model == "convolutional": if nn_model == "convolutional":
x_data = x_data.unsqueeze(1) # Change shape to (samples, 1, H, W) x_data = x_data.unsqueeze(3) # Change shape to (samples, 1, H, W)
elif dataset == 'kmnist': elif dataset == 'kmnist':
kmnist = torchvision.datasets.KMNIST("datasets", download=True) kmnist = torchvision.datasets.KMNIST("datasets", download=True)
x_data, y_data = kmnist.data, kmnist.targets x_data, y_data = kmnist.data, kmnist.targets
if nn_model == "convolutional": if nn_model == "convolutional":
x_data = x_data.unsqueeze(1) # Change shape to (samples, 1, H, W) x_data = x_data.unsqueeze(3) # Change shape to (samples, 1, H, W)
elif dataset == "cifar10": elif dataset == "cifar10":
if nn_model == "linear": if nn_model == "linear":
raise ValueError("CIFAR-10 cannot be used with a linear model. Please use a convolutional model.") raise ValueError("CIFAR-10 cannot be used with a linear model. Please use a convolutional model.")
cifar10 = torchvision.datasets.CIFAR10("datasets", download=True) cifar10 = torchvision.datasets.CIFAR10("datasets", download=True)
x_data, y_data = cifar10.data, cifar10.targets x_data, y_data = cifar10.data, cifar10.targets # (samples, H, W, C)
x_data = np.transpose(x_data, (0, 3, 1, 2)) # Change shape to (samples, C, H, W)
else: else:
sys.exit("Unrecognized dataset. Please make sure you are using one of the following ['mnist', fashion-mnist', 'kmnist']") sys.exit("Unrecognized dataset. Please make sure you are using one of the following ['mnist', fashion-mnist', 'kmnist']")
...@@ -177,7 +176,8 @@ class CustomDataset(Dataset): ...@@ -177,7 +176,8 @@ class CustomDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
sample = self.x_data[idx] sample = self.x_data[idx]
label = self.y_data[idx] label = self.y_data[idx]
#if sample.shape[0] == 3: # This implies CIFAR-10's RGB data
# sample = transforms.ToPILImage()(sample)
if self.transform: if self.transform:
sample = self.transform(sample) sample = self.transform(sample)
...@@ -199,6 +199,7 @@ def data_transformation(row_exp : dict)-> tuple: ...@@ -199,6 +199,7 @@ def data_transformation(row_exp : dict)-> tuple:
''' '''
if row_exp['dataset'] == 'cifar10': if row_exp['dataset'] == 'cifar10':
train_transform = transforms.Compose([ train_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.RandomRotation(20), transforms.RandomRotation(20),
transforms.RandomCrop(32, padding=4), transforms.RandomCrop(32, padding=4),
...@@ -206,10 +207,12 @@ def data_transformation(row_exp : dict)-> tuple: ...@@ -206,10 +207,12 @@ def data_transformation(row_exp : dict)-> tuple:
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]) ])
val_transform = transforms.Compose([ val_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]) ])
test_transform = transforms.Compose([ test_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]) ])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment