Skip to content
Snippets Groups Projects
Commit fe4bfa53 authored by Omar El Rifai's avatar Omar El Rifai
Browse files

Add notebook to repo

parent 2408dc63
Branches
No related tags found
No related merge requests found
%% Cell type:code id:c1649e65-6fb0-4af7-8ecd-d94f44511d9d tags:
``` python
import tarfile
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.utils.data import random_split
import torchvision
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
from src.utils_results import plot_img
from src.fedclass import Client
from src.utils_training import train_central, test_model
from src.utils_data import get_clients_data, data_preparation
from src.models import GenericConvModel
from sklearn.model_selection import train_test_split
```
%% Cell type:markdown id:59c7fe59-a1ce-4925-bdca-8ecb777902e8 tags:
## Three Methods to load the dataset
%% Cell type:code id:60951c57-4f25-4e62-9255-a57d120c6370 tags:
``` python
### 1- Using raw image folder
dataset_url = "https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz"
download_url(dataset_url, '.')
with tarfile.open('./cifar10.tgz', 'r:gz') as tar:
tar.extractall(path='./data')
data_dir = './data/cifar10'
classes = os.listdir(data_dir + "/train")
dataset1 = ImageFolder(data_dir+'/train', transform=ToTensor())
### 2- Using project functions
dict_clients = get_clients_data(num_clients = 1, num_samples_by_label = 600, dataset = 'cifar10', nn_model = 'convolutional')
x_data, y_data = dict_clients[0]['x'], dict_clients[0]['y']
x_data = np.transpose(x_data, (0, 3, 1, 2))
dataset2 = TensorDataset(torch.tensor(x_data, dtype=torch.float32), torch.tensor(y_data, dtype=torch.long))
### 3 - Using CIFAR10 dataset from Pytorch
cifar10 = torchvision.datasets.CIFAR10("datasets", download=True, transform=ToTensor())
(x_data, y_data) = cifar10.data, cifar10.targets
x_data = np.transpose(x_data, (0, 3, 1, 2))
dataset3 = TensorDataset(torch.tensor(x_data, dtype=torch.float32), torch.tensor(y_data, dtype=torch.long))
```
%% Output
Using downloaded and verified file: ./cifar10.tgz
/tmp/ipykernel_6044/2990241823.py:6: DeprecationWarning: Python 3.14 will, by default, filter extracted tar archives and reject files or modify their metadata. Use the filter argument to control this behavior.
tar.extractall(path='./data')
Files already downloaded and verified
Files already downloaded and verified
%% Cell type:code id:2ff13653-be89-4b0b-97f0-ffe7ee9c23ab tags:
``` python
model = GenericConvModel(32,3)
```
%% Cell type:code id:148d883c-a667-49a2-87c1-5962f1c859eb tags:
``` python
```
%% Cell type:markdown id:c4f0dc1d-4cdc-47cb-b200-2c58984ac171 tags:
## Conversion to dataloaders
%% Cell type:code id:bcefeb34-f9f4-4086-8af9-73469c3fd375 tags:
``` python
random_seed = 42
torch.manual_seed(random_seed);
val_size = 5000
train_size = len(dataset1) - val_size
train_ds, val_ds = random_split(dataset1, [train_size, val_size])
batch_size=128
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_dl = DataLoader(val_ds, batch_size*2, num_workers=4, pin_memory=True)
```
%% Cell type:code id:91bec335-be85-4ef8-a27f-298ed08b80fc tags:
``` python
num_epochs = 10
opt_func = torch.optim.Adam
lr = 0.001
@torch.no_grad()
def evaluate(model, val_loader):
model.eval()
outputs = [model.validation_step(batch) for batch in val_loader]
return model.validation_epoch_end(outputs)
def fit(epochs, lr, model, train_loader, val_loader, opt_func=opt_func):
history = []
optimizer = opt_func(model.parameters(), lr)
for epoch in range(epochs):
# Training Phase
model.train()
train_losses = []
for batch in train_loader:
loss = model.training_step(batch)
train_losses.append(loss)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# Validation phase
result = evaluate(model, val_loader)
result['train_loss'] = torch.stack(train_losses).mean().item()
model.epoch_end(epoch, result)
history.append(result)
return history
```
%% Cell type:code id:56aa9198-0a07-4ec8-802b-792352667795 tags:
``` python
history = fit(num_epochs, lr, model, train_dl, val_dl, opt_func)
```
%% Output
Epoch [0], train_loss: 1.7809, val_loss: 1.4422, val_acc: 0.4745
Epoch [1], train_loss: 1.2344, val_loss: 1.0952, val_acc: 0.6092
Epoch [2], train_loss: 0.9971, val_loss: 0.9526, val_acc: 0.6552
Epoch [3], train_loss: 0.8338, val_loss: 0.8339, val_acc: 0.7085
Epoch [4], train_loss: 0.7093, val_loss: 0.7892, val_acc: 0.7239
Epoch [5], train_loss: 0.6082, val_loss: 0.7572, val_acc: 0.7490
%% Cell type:code id:a106673e-a9a9-4525-bc94-98d3b64f2a7d tags:
``` python
result = evaluate(model, test_loader)
```
%% Cell type:code id:24941b20-3aed-4336-9f79-87e4fcf0bba7 tags:
``` python
result
```
%% Output
{'val_loss': 2.3049447536468506, 'val_acc': 0.10572139918804169}
%% Cell type:code id:218e7550-2b48-4a8e-8547-e3afe81d34fe tags:
``` python
```
%% Cell type:code id:15f616e4-e565-4396-9450-03c891530640 tags:
``` python
```
%% Cell type:code id:3fbea674-afff-495a-ac28-c34b44561d47 tags:
``` python
```
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment