From dab33070234188f33862fa00cba1238279b10044 Mon Sep 17 00:00:00 2001 From: Leahcimali <michaelbenalipro@gmail.com> Date: Fri, 11 Oct 2024 10:25:03 +0200 Subject: [PATCH] Correct problem with Quantity Skew --- src/utils_data.py | 3 +-- src/utils_training.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/utils_data.py b/src/utils_data.py index 8385860..1a0058d 100644 --- a/src/utils_data.py +++ b/src/utils_data.py @@ -518,9 +518,8 @@ def apply_quantity_skew(list_clients : list, row_exp : dict, list_skews : list) dict_clients = [get_clients_data(n_clients_by_skew, int(n_max_samples * skew), row_exp['dataset'], - seed=row_exp['seed']) + row_exp['nn_model']) for skew in list_skews] - list_clients = [] for c in range(n_clients_by_skew): diff --git a/src/utils_training.py b/src/utils_training.py index ca2e523..d1cc6d0 100644 --- a/src/utils_training.py +++ b/src/utils_training.py @@ -212,7 +212,7 @@ def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_ # Move the model to the appropriate device model.to(device) - opt_func = torch.optim.SGD # if row_exp['nn_model'] == "linear" else torch.optim.Adam + opt_func = torch.optim.Adam # if row_exp['nn_model'] == "linear" else torch.optim.Adam lr = 0.001 history = [] optimizer = opt_func(model.parameters(), lr) -- GitLab