diff --git a/pytorch_classifier.py b/pytorch_classifier.py
index 808c4fe56187e7747c6fdeebe0120d0e315e1730..0fcc27acdf4d18b35ff6e413505309dcd6f0634d 100644
--- a/pytorch_classifier.py
+++ b/pytorch_classifier.py
@@ -164,25 +164,36 @@ def train(model,
 
         total_acc_train = 0
         total_loss_train = 0
+        batch_counter = 0
 
         for train_input, train_label in tqdm(train_dataloader):
+            batch_counter += 1
             train_label = train_label.to(device)
             mask = train_input['attention_mask'].to(device)
             input_id = train_input['input_ids'].squeeze(1).to(device)
 
             output = model(input_id, mask)
                 
             batch_loss = criterion(output, train_label.long())
             total_loss_train += batch_loss.item()
                 
             acc = (output.argmax(dim=1) == train_label).sum().item()
             total_acc_train += acc
 
-            model.zero_grad()
-            batch_loss.backward()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-            optimizer.step()
-            scheduler.step()
+            # Normalize the loss so the accumulated gradient matches the
+            # gradient of a single large batch, then back-propagate
+            loss = batch_loss / gradient_accumulation_steps
+            loss.backward()
+
+            # Step only every gradient_accumulation_steps batches, and on
+            # the final batch so leftover gradients are not dropped
+            if (batch_counter % gradient_accumulation_steps == 0
+                    or batch_counter == len(train_dataloader)):
+                # Clip the accumulated gradients before updating weights
+                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                optimizer.step()
+                scheduler.step()
+                optimizer.zero_grad()
             
         # ------ Validation --------
         print('\nValidation for epoch:', epoch_num + 1)
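
For reference, the canonical ordering when accumulating gradients is backward()
on the normalized loss for every micro-batch, then clip -> optimizer.step() ->
scheduler.step() -> optimizer.zero_grad() once per accumulation cycle. A minimal
self-contained sketch of that pattern, assuming a generic model and optimizer
(accum_steps and the toy data are illustrative, not from this repo):

    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = torch.nn.CrossEntropyLoss()
    accum_steps = 4  # effective batch size = loader batch size * accum_steps

    # Toy data standing in for a DataLoader
    data = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(10)]

    for step, (x, y) in enumerate(data, start=1):
        loss = criterion(model(x), y) / accum_steps  # normalize per micro-batch
        loss.backward()                              # gradients accumulate in .grad
        if step % accum_steps == 0 or step == len(data):
            # Clip the summed gradients, apply one real update, then reset
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()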