From 236d1dee61f77d31b869d77bac72239497535c33 Mon Sep 17 00:00:00 2001
From: emetheni <lenakmeth@gmail.com>
Date: Mon, 22 May 2023 15:12:52 +0200
Subject: [PATCH] Fix gradient accumulation bug

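Accumulate gradients over several mini-batches before each optimizer update
instead of stepping on every batch. Each batch loss is divided by
gradient_accumulation_steps so the accumulated gradient matches the average
over the effective (larger) batch, and gradient clipping, optimizer.step(),
scheduler.step() and zero_grad() now run once per accumulation window rather
than once per batch.

A minimal sketch of the intended pattern (assuming the same model, criterion,
optimizer, scheduler and gradient_accumulation_steps as in train(); this is
illustrative, not a verbatim excerpt of the change):

    optimizer.zero_grad()
    for i, (inputs, labels) in enumerate(train_dataloader, start=1):
        loss = criterion(model(inputs), labels) / gradient_accumulation_steps
        loss.backward()  # gradients keep adding up across batches
        if i % gradient_accumulation_steps == 0 or i == len(train_dataloader):
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()       # one update per accumulation window
            scheduler.step()
            optimizer.zero_grad()  # start the next window from zero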
---
 pytorch_classifier.py | 34 +++++++++++++++++++++++-----------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/pytorch_classifier.py b/pytorch_classifier.py
index 808c4fe..0fcc27a 100644
--- a/pytorch_classifier.py
+++ b/pytorch_classifier.py
@@ -164,25 +164,37 @@ def train(model,
 
         total_acc_train = 0
         total_loss_train = 0
-
+        batch_counter = 0
+
         for train_input, train_label in tqdm(train_dataloader):
+            batch_counter += 1
             train_label = train_label.to(device)
             mask = train_input['attention_mask'].to(device)
             input_id = train_input['input_ids'].squeeze(1).to(device)
 
             output = model(input_id, mask)
                 
-            batch_loss = criterion(output, train_label.long())
-            total_loss_train += batch_loss.item()
+            # Track the un-normalized loss and accuracy for the epoch statistics
+            batch_loss = criterion(output, train_label.long())
+            total_loss_train += batch_loss.item()
                 
-            acc = (output.argmax(dim=1) == train_label).sum().item()
-            total_acc_train += acc
-
-            model.zero_grad()
-            batch_loss.backward()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-            optimizer.step()
-            scheduler.step()
+            acc = (output.argmax(dim=1) == train_label).sum().item()
+            total_acc_train += acc
+
+            # Scale the loss so the gradients accumulated over
+            # gradient_accumulation_steps batches average out, then back-propagate
+            loss = batch_loss / gradient_accumulation_steps
+            loss.backward()
+
+            # Only update the weights every gradient_accumulation_steps batches
+            # (or on the last batch, so leftover gradients are not dropped)
+            if (batch_counter % gradient_accumulation_steps == 0) or (batch_counter == len(train_dataloader)):
+                # Clip the accumulated gradients before the optimizer step
+                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                optimizer.step()
+                scheduler.step()
+                # Reset the gradients for the next accumulation window
+                optimizer.zero_grad()
             
         # ------ Validation --------
         print('\nValidation for epoch:', epoch_num + 1)
-- 
GitLab