From 50aa065d3ea6e6d2ed1352545670a0e8d0a31a10 Mon Sep 17 00:00:00 2001
From: PNRIA - Julien <julien.rabault@irit.fr>
Date: Wed, 18 May 2022 10:52:59 +0200
Subject: [PATCH] fix epoch

---
 SuperTagger/SuperTagger.py | 10 +++++-----
 train.py                   |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py
index 8e211f7..a55c8f2 100644
--- a/SuperTagger/SuperTagger.py
+++ b/SuperTagger/SuperTagger.py
@@ -236,21 +236,21 @@ class SuperTagger:
                 eval_accuracy, eval_loss, nb_eval_steps = self.__eval_epoch(validation_dataloader)
 
             print("")
-            print(f'Epoch: {epoch_i:02} | Epoch Time: {training_time}')
+            print(f'Epoch: {epoch_i+1:02} | Epoch Time: {training_time}')
             print(f'\tTrain Loss: {epoch_loss:.3f} | Train Acc: {epoch_acc * 100:.2f}%')
             if validation_rate > 0.0:
                 print(f'\tVal Loss: {eval_loss:.3f} | Val Acc: {eval_accuracy * 100:.2f}%')
 
             if tensorboard:
                 writer.add_scalars(f'Accuracy', {
-                    'Train': epoch_acc}, epoch_i)
+                    'Train': epoch_acc}, epoch_i+1)
                 writer.add_scalars(f'Loss', {
-                    'Train': epoch_loss}, epoch_i)
+                    'Train': epoch_loss}, epoch_i+1)
                 if validation_rate > 0.0:
                     writer.add_scalars(f'Accuracy', {
-                        'Validation': eval_accuracy}, epoch_i)
+                        'Validation': eval_accuracy}, epoch_i+1)
                     writer.add_scalars(f'Loss', {
-                        'Validation': eval_loss}, epoch_i)
+                        'Validation': eval_loss}, epoch_i+1)
 
             self.epoch_i += 1
 
diff --git a/train.py b/train.py
index d3cb4ef..9ac77d4 100644
--- a/train.py
+++ b/train.py
@@ -4,7 +4,7 @@ from SuperTagger.Utils.helpers import read_csv_pgbar, load_obj
 #### DATA ####
 file_path = 'Datasets/m2_dataset.csv'
 
-df = read_csv_pgbar(file_path,1000)
+df = read_csv_pgbar(file_path,100)
 
 texts = df['X'].tolist()
 tags = df['Y1'].tolist()
-- 
GitLab