Skip to content
Snippets Groups Projects
Commit dab33070 authored by Leahcimali's avatar Leahcimali
Browse files

Correct problem with Quantity Skew

parent 746520a8
No related branches found
No related tags found
No related merge requests found
...@@ -518,9 +518,8 @@ def apply_quantity_skew(list_clients : list, row_exp : dict, list_skews : list) ...@@ -518,9 +518,8 @@ def apply_quantity_skew(list_clients : list, row_exp : dict, list_skews : list)
dict_clients = [get_clients_data(n_clients_by_skew, dict_clients = [get_clients_data(n_clients_by_skew,
int(n_max_samples * skew), int(n_max_samples * skew),
row_exp['dataset'], row_exp['dataset'],
seed=row_exp['seed']) row_exp['nn_model'])
for skew in list_skews] for skew in list_skews]
list_clients = [] list_clients = []
for c in range(n_clients_by_skew): for c in range(n_clients_by_skew):
......
...@@ -212,7 +212,7 @@ def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_ ...@@ -212,7 +212,7 @@ def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_
# Move the model to the appropriate device # Move the model to the appropriate device
model.to(device) model.to(device)
opt_func = torch.optim.SGD # if row_exp['nn_model'] == "linear" else torch.optim.Adam opt_func = torch.optim.Adam # if row_exp['nn_model'] == "linear" else torch.optim.Adam
lr = 0.001 lr = 0.001
history = [] history = []
optimizer = opt_func(model.parameters(), lr) optimizer = opt_func(model.parameters(), lr)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment