diff --git a/pytorch_classifier.py b/pytorch_classifier.py
index 808c4fe56187e7747c6fdeebe0120d0e315e1730..0fcc27acdf4d18b35ff6e413505309dcd6f0634d 100644
--- a/pytorch_classifier.py
+++ b/pytorch_classifier.py
@@ -164,23 +164,33 @@ def train(model,
             total_acc_train = 0
             total_loss_train = 0
+            batch_counter = 0
 
             for train_input, train_label in tqdm(train_dataloader):
+                batch_counter += 1
                 train_label = train_label.to(device)
                 mask = train_input['attention_mask'].to(device)
                 input_id = train_input['input_ids'].squeeze(1).to(device)
 
                 output = model(input_id, mask)
                 batch_loss = criterion(output, train_label.long())
                 total_loss_train += batch_loss.item()
 
                 acc = (output.argmax(dim=1) == train_label).sum().item()
                 total_acc_train += acc
 
-                model.zero_grad()
-                batch_loss.backward()
-                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-                optimizer.step()
-                scheduler.step()
+                # Scale the loss so the gradients accumulated over
+                # gradient_accumulation_steps mini-batches match one full batch
+                loss = batch_loss / gradient_accumulation_steps
+                loss.backward()
+
+                # Update only every gradient_accumulation_steps batches (and on the
+                # last batch): clip, step the optimizer and scheduler, then zero grads
+                if (batch_counter % gradient_accumulation_steps == 0
+                        or batch_counter == len(train_dataloader)):
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                    optimizer.step()
+                    scheduler.step()
+                    optimizer.zero_grad()
 
         # ------ Validation --------
         print('\nValidation for epoch:', epoch_num + 1)
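
For reference, a minimal, self-contained sketch of the same gradient-accumulation pattern; the function name, argument names, and the accumulation_steps default are illustrative and not taken from pytorch_classifier.py:

import torch

def train_one_epoch(model, dataloader, criterion, optimizer, scheduler,
                    device, accumulation_steps=4):
    # Hypothetical helper illustrating gradient accumulation; names are assumed.
    model.train()
    optimizer.zero_grad()
    for step, (inputs, labels) in enumerate(dataloader, start=1):
        inputs, labels = inputs.to(device), labels.to(device)
        loss = criterion(model(inputs), labels)
        # Scale so the summed gradients approximate one large-batch gradient
        (loss / accumulation_steps).backward()
        # Update every accumulation_steps mini-batches, and flush on the last batch
        if step % accumulation_steps == 0 or step == len(dataloader):
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()  # step the LR scheduler once per optimizer update
            optimizer.zero_grad()

On ordering: gradients are clipped only once they are fully accumulated, optimizer.step() comes before optimizer.zero_grad() so the update uses the accumulated gradients, and the scheduler advances once per optimizer update rather than once per mini-batch.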