Skip to content
Snippets Groups Projects
Commit 52c42792 authored by Leahcimali's avatar Leahcimali
Browse files

Update Training for Cuda use

parent def20d90
No related branches found
No related tags found
No related merge requests found
...@@ -177,14 +177,21 @@ def train_federated(main_model, list_clients, row_exp, use_cluster_models = Fals ...@@ -177,14 +177,21 @@ def train_federated(main_model, list_clients, row_exp, use_cluster_models = Fals
return main_model return main_model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@torch.no_grad() @torch.no_grad()
def evaluate(model : nn.Module, val_loader : DataLoader) -> dict: def evaluate(model : nn.Module, val_loader : DataLoader) -> dict:
""" Returns a dict with loss and accuracy information""" """ Returns a dict with loss and accuracy information"""
model.to(device)
model.eval() model.eval()
outputs = [model.validation_step(batch) for batch in val_loader] outputs =[]
for batch in val_loader:
# Move entire batch to the correct device
batch = [item.to(device) for item in batch]
# Call the validation step and append to outputs
output = model.validation_step(batch,device)
outputs.append(output)
return model.validation_epoch_end(outputs) return model.validation_epoch_end(outputs)
def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_loader: DataLoader, row_exp: dict): def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_loader: DataLoader, row_exp: dict):
...@@ -201,13 +208,12 @@ def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_ ...@@ -201,13 +208,12 @@ def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_
""" """
# Check if CUDA is available and set the device # Check if CUDA is available and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model to the appropriate device # Move the model to the appropriate device
model.to(device) model.to(device)
opt_func = torch.optim.SGD # if row_exp['nn_model'] == "linear" else torch.optim.Adam opt_func = torch.optim.Adam # if row_exp['nn_model'] == "linear" else torch.optim.Adam
lr = 0.01 lr = 0.001
history = [] history = []
optimizer = opt_func(model.parameters(), lr) optimizer = opt_func(model.parameters(), lr)
...@@ -218,15 +224,16 @@ def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_ ...@@ -218,15 +224,16 @@ def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_
for batch in train_loader: for batch in train_loader:
# Move batch to the same device as the model # Move batch to the same device as the model
batch = [item.to(device) for item in batch] # Assuming batch is a tuple of tensors inputs, labels = [item.to(device) for item in batch]
loss = model.training_step(batch) # Pass the unpacked inputs and labels to the model's training step
loss = model.training_step((inputs, labels),device)
train_losses.append(loss) train_losses.append(loss)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
result = evaluate(model, val_loader) # Ensure evaluate handles CUDA as needed result = evaluate(model, val_loader) # Ensure evaluate handles CUDA as needed
result['train_loss'] = torch.stack(train_losses).mean().item() result['train_loss'] = torch.stack(train_losses).mean().item()
...@@ -262,12 +269,8 @@ def test_model(model: nn.Module, test_loader: DataLoader) -> float: ...@@ -262,12 +269,8 @@ def test_model(model: nn.Module, test_loader: DataLoader) -> float:
with torch.no_grad(): # No need to track gradients in evaluation with torch.no_grad(): # No need to track gradients in evaluation
for inputs, labels in test_loader: for batch in test_loader:
inputs, labels = [item.to(device) for item in batch]
# Move inputs and labels to the device
inputs, labels = inputs.to(device), labels.to(device)
# Forward pass
outputs = model(inputs) outputs = model(inputs)
# Compute the loss # Compute the loss
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment