diff --git a/src/utils_data.py b/src/utils_data.py index afa820cace1fa73ae6ee05e4c133494e191a108f..a5dc8a353aed0d220ac0a814f576984d8e8a79fa 100644 --- a/src/utils_data.py +++ b/src/utils_data.py @@ -250,7 +250,9 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]: import torch list_clients = [] - + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + torch.manual_seed(row_exp['seed']) imgs_params = {'mnist': (28,1) , 'fashion-mnist': (28,1), 'kmnist': (28,1), 'cifar10': (32,3)} @@ -263,6 +265,8 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]: model_server = Server(GenericConvModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1])) + model_server.model.to(device) + dict_clients = get_clients_data(row_exp['num_clients'], row_exp['num_samples_by_label'], row_exp['dataset'],