From 83348f60248cda479cf4e1b938c436ea13837224 Mon Sep 17 00:00:00 2001
From: leahcimali <michaelbenalipro@gmail.com>
Date: Tue, 1 Oct 2024 16:40:42 +0200
Subject: [PATCH] Add run_exp.py and CUDA support for training

---
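Note: the CUDA changes below move the model and the training/test batches onto
the selected device, but evaluate() (left unchanged) still passes validation
batches to model.validation_step() as-is, so validation_step is assumed to move
them itself. A minimal device-aware sketch, under the same assumption as
train_central that each batch is a tuple of tensors, could look like:

    def evaluate(model: nn.Module, val_loader: DataLoader) -> dict:
        # Run validation on whatever device the model's parameters live on
        device = next(model.parameters()).device
        outputs = []
        with torch.no_grad():
            for batch in val_loader:
                # Assumption: each batch is a tuple/list of tensors, as in train_central
                batch = [item.to(device) for item in batch]
                outputs.append(model.validation_step(batch))
        return model.validation_epoch_end(outputs)
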
 run_exp.py            | 34 ++++++++++++++++
 src/utils_training.py | 92 +++++++++++++++++++++++++------------------
 2 files changed, 88 insertions(+), 38 deletions(-)
 create mode 100644 run_exp.py
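
Note: run_exp.py assumes exp_configs.csv starts with a header row followed by
one experiment per line, with the columns in the order they are unpacked:

    exp_type,dataset,nn_model,heterogeneity_type,num_clients,num_samples_by_label,num_clusters,centralized_epochs,federated_rounds,seed

Only the first data row is launched; each value is forwarded verbatim to
driver.py as the command-line flag of the same name.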

diff --git a/run_exp.py b/run_exp.py
new file mode 100644
index 0000000..cd988a0
--- /dev/null
+++ b/run_exp.py
@@ -0,0 +1,34 @@
+import csv
+import subprocess
+
+# Path to your CSV file
+csv_file = "exp_configs.csv"
+
+# Read the second line from the CSV file
+with open(csv_file, newline='') as csvfile:
+    reader = csv.reader(csvfile)
+    
+    # Skip the header row, then read the first experiment row
+    next(reader)  # Skip the header
+    row = next(reader)  # First data row (the second line of the file)
+
+    # Assigning CSV values to variables
+    exp_type, dataset, nn_model, heterogeneity_type, num_clients, num_samples_by_label, num_clusters, centralized_epochs, federated_rounds, seed = row
+
+    # Building the command
+    command = [
+        "python", "driver.py",
+        "--exp_type", exp_type,
+        "--dataset", dataset,
+        "--nn_model", nn_model,
+        "--heterogeneity_type", heterogeneity_type,
+        "--num_clients", num_clients,
+        "--num_samples_by_label", num_samples_by_label,
+        "--num_clusters", num_clusters,
+        "--centralized_epochs", centralized_epochs,
+        "--federated_rounds", federated_rounds,
+        "--seed", seed
+    ]
+
+    # Run the command and fail loudly if driver.py exits with an error
+    subprocess.run(command, check=True)
diff --git a/src/utils_training.py b/src/utils_training.py
index 3805c6e..44a7c17 100644
--- a/src/utils_training.py
+++ b/src/utils_training.py
@@ -112,33 +112,29 @@ def run_benchmark(main_model : nn.Module, list_clients : list, row_exp : dict) -
 
     curr_model = main_model if row_exp['exp_type'] == 'global-federated' else main_model.model
 
-    match row_exp['exp_type']:
+    if row_exp['exp_type'] == 'pers-centralized':
+        for heterogeneity_class in list_heterogeneities:
+            list_clients_filtered = [client for client in list_clients if client.heterogeneity_class == heterogeneity_class]
+            train_loader, val_loader, test_loader = centralize_data(list_clients_filtered)
+            model_trained, _ = train_central(curr_model, train_loader, val_loader, row_exp) 
+
+            global_acc = test_model(model_trained, test_loader) 
+
+            for client in list_clients_filtered:
     
-        case 'pers-centralized':
-
-            for heterogeneity_class in list_heterogeneities:
-                
-                list_clients_filtered = [client for client in list_clients if client.heterogeneity_class == heterogeneity_class]
-                train_loader, val_loader, test_loader = centralize_data(list_clients_filtered)
-                model_trained, _ = train_central(curr_model, train_loader, val_loader, row_exp) 
-
-                global_acc = test_model(model_trained, test_loader) 
-                     
-                for client in list_clients_filtered : 
-        
-                    setattr(client, 'accuracy', global_acc)
+                setattr(client, 'accuracy', global_acc)
     
-        case 'global-federated':
+    elif row_exp['exp_type'] == 'global-federated':
                 
-            model_server = copy.deepcopy(curr_model)
-            model_trained = train_federated(model_server, list_clients, row_exp, use_cluster_models = False)
+        model_server = copy.deepcopy(curr_model)
+        model_trained = train_federated(model_server, list_clients, row_exp, use_cluster_models = False)
 
-            _, _,test_loader = centralize_data(list_clients)
-            global_acc = test_model(model_trained.model, test_loader) 
-                     
-            for client in list_clients : 
-        
-                setattr(client, 'accuracy', global_acc)
+        _, _, test_loader = centralize_data(list_clients)
+        global_acc = test_model(model_trained.model, test_loader)
+
+        for client in list_clients:
+
+            setattr(client, 'accuracy', global_acc)
 
     df_results = pd.DataFrame.from_records([c.to_dict() for c in list_clients])
     
@@ -191,10 +187,7 @@ def evaluate(model : nn.Module, val_loader : DataLoader) -> dict:
     outputs = [model.validation_step(batch) for batch in val_loader]
     return model.validation_epoch_end(outputs)
 
-
-
-def train_central(model : ImageClassificationBase, train_loader : DataLoader, val_loader : DataLoader, row_exp : dict):
-
+def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_loader: DataLoader, row_exp: dict):
     """ Main training function for centralized learning
     
     Arguments:
@@ -207,18 +200,26 @@ def train_central(model : ImageClassificationBase, train_loader : DataLoader, va
         (model, history) : base model with trained weights / results at each training step
     """
 
-    opt_func=torch.optim.SGD #if row_exp['nn_model'] == "linear" else torch.optim.Adam
+    # Check if CUDA is available and set the device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    
+    # Move the model to the appropriate device
+    model.to(device)
+
+    opt_func = torch.optim.SGD  # if row_exp['nn_model'] == "linear" else torch.optim.Adam
     lr = 0.01
     history = []
     optimizer = opt_func(model.parameters(), lr)
-    
+
     for epoch in range(row_exp['centralized_epochs']):
         
         model.train()
         train_losses = []
         
         for batch in train_loader:
-
+            # Move batch to the same device as the model
+            batch = [item.to(device) for item in batch]  # Assuming batch is a tuple of tensors
+            
             loss = model.training_step(batch)
             train_losses.append(loss)
             loss.backward()
@@ -226,7 +227,7 @@ def train_central(model : ImageClassificationBase, train_loader : DataLoader, va
             optimizer.step()
             optimizer.zero_grad()
         
-        result = evaluate(model, val_loader)
+        result = evaluate(model, val_loader)  # Note: evaluate() passes batches as-is; validation_step is assumed to move them to the device
         result['train_loss'] = torch.stack(train_losses).mean().item()        
         
         model.epoch_end(epoch, result)
@@ -234,11 +235,11 @@ def train_central(model : ImageClassificationBase, train_loader : DataLoader, va
         history.append(result)
     
     return model, history
-    
 
-def test_model(model : nn.Module, test_loader : DataLoader) -> float:
+    
 
-    """ Calcualtes model accuracy (percentage) on the <test_loader> Dataset
+def test_model(model: nn.Module, test_loader: DataLoader) -> float:
+    """ Calculates model accuracy (percentage) on the <test_loader> Dataset
     
     Arguments:
         model : the input server model
@@ -247,28 +248,43 @@ def test_model(model : nn.Module, test_loader : DataLoader) -> float:
     
     criterion = nn.CrossEntropyLoss()
 
-    model.eval()
+    # Set device to CUDA if available
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    
+    # Move the model to the device
+    model.to(device)
+
+    model.eval()  # Set the model to evaluation mode
 
     correct = 0
     total = 0
     test_loss = 0.0
 
-    with torch.no_grad():
+    with torch.no_grad():  # No need to track gradients in evaluation
 
         for inputs, labels in test_loader:
-
+            
+            # Move inputs and labels to the device
+            inputs, labels = inputs.to(device), labels.to(device)
+            
+            # Forward pass
             outputs = model(inputs)
 
+            # Compute the loss
             loss = criterion(outputs, labels)
             test_loss += loss.item() * inputs.size(0)
 
+            # Get predictions
             _, predicted = torch.max(outputs, 1)
 
+            # Calculate total and correct predictions
             total += labels.size(0)
             correct += (predicted == labels).sum().item()
 
+    # Average test loss over all examples
     test_loss = test_loss / len(test_loader.dataset)
 
+    # Calculate accuracy percentage
     accuracy = (correct / total) * 100
 
-    return accuracy
\ No newline at end of file
+    return accuracy
-- 
GitLab