From 370c2552081f7808acb8328bed5f97d87742eaff Mon Sep 17 00:00:00 2001 From: Omar El Rifai <omar.void@gmail.com> Date: Fri, 13 Sep 2024 14:01:52 +0200 Subject: [PATCH] Add gpu capability --- src/utils_data.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/utils_data.py b/src/utils_data.py index afa820c..a5dc8a3 100644 --- a/src/utils_data.py +++ b/src/utils_data.py @@ -250,7 +250,9 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]: import torch list_clients = [] - + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + torch.manual_seed(row_exp['seed']) imgs_params = {'mnist': (28,1) , 'fashion-mnist': (28,1), 'kmnist': (28,1), 'cifar10': (32,3)} @@ -263,6 +265,8 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]: model_server = Server(GenericConvModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1])) + model_server.model.to(device) + dict_clients = get_clients_data(row_exp['num_clients'], row_exp['num_samples_by_label'], row_exp['dataset'], -- GitLab