From fe4bfa53e68c52f01e73cc282a3eb9ceb0174ed3 Mon Sep 17 00:00:00 2001 From: Omar El Rifai <omar.void@gmail.com> Date: Thu, 12 Sep 2024 11:12:34 +0200 Subject: [PATCH] Add notebook to repo --- draft.ipynb | 291 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 291 insertions(+) create mode 100644 draft.ipynb diff --git a/draft.ipynb b/draft.ipynb new file mode 100644 index 0000000..8349955 --- /dev/null +++ b/draft.ipynb @@ -0,0 +1,291 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "id": "c1649e65-6fb0-4af7-8ecd-d94f44511d9d", + "metadata": {}, + "outputs": [], + "source": [ + "import tarfile\n", + "import os \n", + "\n", + "\n", + "import numpy as np\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import Dataset, DataLoader, TensorDataset\n", + "from torch.utils.data import random_split\n", + "import torchvision\n", + "from torchvision.datasets.utils import download_url\n", + "from torchvision.datasets import ImageFolder\n", + "from torchvision.transforms import ToTensor\n", + "import torchvision.transforms as transforms\n", + "\n", + "\n", + "from src.utils_results import plot_img\n", + "from src.fedclass import Client\n", + "from src.utils_training import train_central, test_model\n", + "from src.utils_data import get_clients_data, data_preparation\n", + "from src.models import GenericConvModel\n", + "from sklearn.model_selection import train_test_split\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "59c7fe59-a1ce-4925-bdca-8ecb777902e8", + "metadata": {}, + "source": [ + "## Three Methods to load the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "60951c57-4f25-4e62-9255-a57d120c6370", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using downloaded and verified file: ./cifar10.tgz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_6044/2990241823.py:6: DeprecationWarning: Python 3.14 will, by default, filter extracted tar archives and reject files or modify their metadata. Use the filter argument to control this behavior.\n", + " tar.extractall(path='./data')\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n" + ] + } + ], + "source": [ + "### 1- Using raw image folder\n", + "\n", + "dataset_url = \"https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz\"\n", + "download_url(dataset_url, '.')\n", + "with tarfile.open('./cifar10.tgz', 'r:gz') as tar:\n", + " tar.extractall(path='./data')\n", + "data_dir = './data/cifar10'\n", + "\n", + "classes = os.listdir(data_dir + \"/train\")\n", + "dataset1 = ImageFolder(data_dir+'/train', transform=ToTensor())\n", + "\n", + "\n", + "\n", + "### 2- Using project functions\n", + "\n", + "dict_clients = get_clients_data(num_clients = 1, num_samples_by_label = 600, dataset = 'cifar10', nn_model = 'convolutional')\n", + "x_data, y_data = dict_clients[0]['x'], dict_clients[0]['y']\n", + "x_data = np.transpose(x_data, (0, 3, 1, 2))\n", + "\n", + "dataset2 = TensorDataset(torch.tensor(x_data, dtype=torch.float32), torch.tensor(y_data, dtype=torch.long))\n", + "\n", + "\n", + "\n", + "### 3 - Using CIFAR10 dataset from Pytorch\n", + "\n", + "cifar10 = torchvision.datasets.CIFAR10(\"datasets\", download=True, transform=ToTensor())\n", + "(x_data, y_data) = cifar10.data, cifar10.targets\n", + "x_data = np.transpose(x_data, (0, 3, 1, 2))\n", + "dataset3 = TensorDataset(torch.tensor(x_data, dtype=torch.float32), torch.tensor(y_data, dtype=torch.long))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "2ff13653-be89-4b0b-97f0-ffe7ee9c23ab", + "metadata": {}, + "outputs": [], + "source": [ + "model = GenericConvModel(32,3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "148d883c-a667-49a2-87c1-5962f1c859eb", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "c4f0dc1d-4cdc-47cb-b200-2c58984ac171", + "metadata": {}, + "source": [ + "## Conversion to dataloaders" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "bcefeb34-f9f4-4086-8af9-73469c3fd375", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "random_seed = 42\n", + "torch.manual_seed(random_seed);\n", + "val_size = 5000\n", + "train_size = len(dataset1) - val_size\n", + "\n", + "train_ds, val_ds = random_split(dataset1, [train_size, val_size])\n", + "\n", + "batch_size=128\n", + "train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=4, pin_memory=True)\n", + "val_dl = DataLoader(val_ds, batch_size*2, num_workers=4, pin_memory=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "91bec335-be85-4ef8-a27f-298ed08b80fc", + "metadata": {}, + "outputs": [], + "source": [ + "num_epochs = 10\n", + "opt_func = torch.optim.Adam\n", + "lr = 0.001\n", + "\n", + "@torch.no_grad()\n", + "def evaluate(model, val_loader):\n", + " model.eval()\n", + " outputs = [model.validation_step(batch) for batch in val_loader]\n", + " return model.validation_epoch_end(outputs)\n", + "\n", + "def fit(epochs, lr, model, train_loader, val_loader, opt_func=opt_func):\n", + " history = []\n", + " optimizer = opt_func(model.parameters(), lr)\n", + " for epoch in range(epochs):\n", + " # Training Phase \n", + " model.train()\n", + " train_losses = []\n", + " for batch in train_loader:\n", + " loss = model.training_step(batch)\n", + " train_losses.append(loss)\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " # Validation phase\n", + " result = evaluate(model, val_loader)\n", + " result['train_loss'] = torch.stack(train_losses).mean().item()\n", + " model.epoch_end(epoch, result)\n", + " history.append(result)\n", + " return history\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56aa9198-0a07-4ec8-802b-792352667795", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [0], train_loss: 1.7809, val_loss: 1.4422, val_acc: 0.4745\n", + "Epoch [1], train_loss: 1.2344, val_loss: 1.0952, val_acc: 0.6092\n", + "Epoch [2], train_loss: 0.9971, val_loss: 0.9526, val_acc: 0.6552\n", + "Epoch [3], train_loss: 0.8338, val_loss: 0.8339, val_acc: 0.7085\n", + "Epoch [4], train_loss: 0.7093, val_loss: 0.7892, val_acc: 0.7239\n", + "Epoch [5], train_loss: 0.6082, val_loss: 0.7572, val_acc: 0.7490\n" + ] + } + ], + "source": [ + "history = fit(num_epochs, lr, model, train_dl, val_dl, opt_func)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "a106673e-a9a9-4525-bc94-98d3b64f2a7d", + "metadata": {}, + "outputs": [], + "source": [ + "result = evaluate(model, test_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "24941b20-3aed-4336-9f79-87e4fcf0bba7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'val_loss': 2.3049447536468506, 'val_acc': 0.10572139918804169}" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "218e7550-2b48-4a8e-8547-e3afe81d34fe", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f616e4-e565-4396-9450-03c891530640", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3fbea674-afff-495a-ac28-c34b44561d47", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} -- GitLab