diff --git a/src/utils_data.py b/src/utils_data.py index 838586092434938865f6c91de08b3b3d8449d6b5..1a0058d33746eb509b0bd323878436af448acb3b 100644 --- a/src/utils_data.py +++ b/src/utils_data.py @@ -518,9 +518,8 @@ def apply_quantity_skew(list_clients : list, row_exp : dict, list_skews : list) dict_clients = [get_clients_data(n_clients_by_skew, int(n_max_samples * skew), row_exp['dataset'], - seed=row_exp['seed']) + row_exp['nn_model']) for skew in list_skews] - list_clients = [] for c in range(n_clients_by_skew): diff --git a/src/utils_training.py b/src/utils_training.py index ca2e523a3553279c11a0f244402c83a2929e9dc8..d1cc6d03d6e94316f29a4538d7e2bbc6cfda779f 100644 --- a/src/utils_training.py +++ b/src/utils_training.py @@ -212,7 +212,7 @@ def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_ # Move the model to the appropriate device model.to(device) - opt_func = torch.optim.SGD # if row_exp['nn_model'] == "linear" else torch.optim.Adam + opt_func = torch.optim.Adam # if row_exp['nn_model'] == "linear" else torch.optim.Adam lr = 0.001 history = [] optimizer = opt_func(model.parameters(), lr)