From 370c2552081f7808acb8328bed5f97d87742eaff Mon Sep 17 00:00:00 2001
From: Omar El Rifai <omar.void@gmail.com>
Date: Fri, 13 Sep 2024 14:01:52 +0200
Subject: [PATCH] Add gpu capability

---
 src/utils_data.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/utils_data.py b/src/utils_data.py
index afa820c..a5dc8a3 100644
--- a/src/utils_data.py
+++ b/src/utils_data.py
@@ -250,7 +250,9 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
     import torch
     
     list_clients = []
-    
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
     torch.manual_seed(row_exp['seed'])
 
     imgs_params = {'mnist': (28,1) , 'fashion-mnist': (28,1), 'kmnist': (28,1), 'cifar10': (32,3)}
@@ -263,6 +265,8 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
         
         model_server = Server(GenericConvModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
 
+    model_server.model.to(device)
+
     dict_clients = get_clients_data(row_exp['num_clients'],
                                     row_exp['num_samples_by_label'],
                                     row_exp['dataset'],
-- 
GitLab