Commit 52c42792 authored by Leahcimali

Update training for CUDA use

parent def20d90
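The diff below applies the standard PyTorch device-placement pattern throughout the training utilities: select "cuda" when a GPU is visible, move the model's parameters once, and move every batch to the same device before the forward pass. As a self-contained reference (illustrative model and tensor names, not code from this repository), the pattern is:

import torch
import torch.nn as nn

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(10, 2).to(device)            # move parameters and buffers once
inputs = torch.randn(4, 10).to(device)         # move each batch before the forward pass
labels = torch.randint(0, 2, (4,)).to(device)

loss = nn.functional.cross_entropy(model(inputs), labels)
loss.backward()

Every change in the commit is an instance of this pattern, plus one optimizer swap.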
@@ -177,14 +177,21 @@ def train_federated(main_model, list_clients, row_exp, use_cluster_models = False):
 
     return main_model
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 @torch.no_grad()
 def evaluate(model : nn.Module, val_loader : DataLoader) -> dict:
 
     """ Returns a dict with loss and accuracy information"""
 
+    model.to(device)
     model.eval()
-    outputs = [model.validation_step(batch) for batch in val_loader]
+    outputs = []
+    for batch in val_loader:
+        # Move the entire batch to the correct device
+        batch = [item.to(device) for item in batch]
+        # Call the validation step and append to outputs
+        output = model.validation_step(batch, device)
+        outputs.append(output)
 
     return model.validation_epoch_end(outputs)
 
 def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_loader: DataLoader, row_exp: dict):
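The rewritten evaluate passes device into model.validation_step, so the model class must now accept that argument. ImageClassificationBase is defined outside this diff; a minimal sketch of validation_step/validation_epoch_end methods compatible with the new call sites (an assumption about the class, not its actual code) could be:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageClassificationBase(nn.Module):
    # Sketch only; the real class lives in another file of the repository.
    def validation_step(self, batch, device):
        images, labels = batch  # already moved to `device` by the caller
        out = self(images)
        loss = F.cross_entropy(out, labels)
        acc = (out.argmax(dim=1) == labels).float().mean()
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        # Reduce the per-batch dicts to the single dict evaluate() returns.
        return {'val_loss': torch.stack([x['val_loss'] for x in outputs]).mean().item(),
                'val_acc': torch.stack([x['val_acc'] for x in outputs]).mean().item()}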
@@ -201,13 +208,12 @@ def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_loader: DataLoader, row_exp: dict):
 
     """
 
     # Check if CUDA is available and set the device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     # Move the model to the appropriate device
     model.to(device)
 
-    opt_func = torch.optim.SGD  # if row_exp['nn_model'] == "linear" else torch.optim.Adam
-    lr = 0.01
+    opt_func = torch.optim.Adam  # if row_exp['nn_model'] == "linear" else torch.optim.Adam
+    lr = 0.001
 
     history = []
     optimizer = opt_func(model.parameters(), lr)
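This hunk swaps SGD for Adam and drops the learning rate from 0.01 to 0.001, Adam's conventional default. The inline comment suggests the optimizer was once chosen per model type; if that conditional were restored, it would presumably read something like this (reconstructed from the comment, hypothetical, not present in the repository):

opt_func = torch.optim.SGD if row_exp['nn_model'] == "linear" else torch.optim.Adam
optimizer = opt_func(model.parameters(), lr)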
@@ -218,15 +224,16 @@ def train_central(model: ImageClassificationBase, train_loader: DataLoader, val_loader: DataLoader, row_exp: dict):
 
         for batch in train_loader:
 
             # Move batch to the same device as the model
-            batch = [item.to(device) for item in batch]  # Assuming batch is a tuple of tensors
-            loss = model.training_step(batch)
+            inputs, labels = [item.to(device) for item in batch]
+            # Pass the unpacked inputs and labels to the model's training step
+            loss = model.training_step((inputs, labels), device)
             train_losses.append(loss)
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
 
         result = evaluate(model, val_loader)  # Ensure evaluate handles CUDA as needed
         result['train_loss'] = torch.stack(train_losses).mean().item()
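training_step is now called as model.training_step((inputs, labels), device), mirroring the evaluate change. Under the same assumptions as the validation_step sketch above, a compatible method might be:

def training_step(self, batch, device):
    # Sketch only: inputs and labels are expected to arrive on `device`;
    # the argument lets the method place any extra tensors it creates.
    images, labels = batch
    out = self(images)                    # forward pass on the model's device
    return F.cross_entropy(out, labels)   # scalar loss; the caller runs backward()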
@@ -262,12 +269,8 @@ def test_model(model: nn.Module, test_loader: DataLoader) -> float:
 
     with torch.no_grad():  # No need to track gradients in evaluation
 
-        for inputs, labels in test_loader:
-
-            # Move inputs and labels to the device
-            inputs, labels = inputs.to(device), labels.to(device)
-
-            # Forward pass
+        for batch in test_loader:
+            inputs, labels = [item.to(device) for item in batch]
             outputs = model(inputs)
 
             # Compute the loss
......
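A design note on the transfers: the per-item [item.to(device) for item in batch] copies are synchronous. A common refinement, not part of this commit, is to pin host memory in the DataLoader and request non-blocking copies so the transfer can overlap GPU compute (train_dataset is a placeholder name):

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True)

for batch in train_loader:
    inputs, labels = [item.to(device, non_blocking=True) for item in batch]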