From 50aa065d3ea6e6d2ed1352545670a0e8d0a31a10 Mon Sep 17 00:00:00 2001 From: PNRIA - Julien <julien.rabault@irit.fr> Date: Wed, 18 May 2022 10:52:59 +0200 Subject: [PATCH] fix epoch --- SuperTagger/SuperTagger.py | 10 +++++----- train.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py index 8e211f7..a55c8f2 100644 --- a/SuperTagger/SuperTagger.py +++ b/SuperTagger/SuperTagger.py @@ -236,21 +236,21 @@ class SuperTagger: eval_accuracy, eval_loss, nb_eval_steps = self.__eval_epoch(validation_dataloader) print("") - print(f'Epoch: {epoch_i:02} | Epoch Time: {training_time}') + print(f'Epoch: {epoch_i+1:02} | Epoch Time: {training_time}') print(f'\tTrain Loss: {epoch_loss:.3f} | Train Acc: {epoch_acc * 100:.2f}%') if validation_rate > 0.0: print(f'\tVal Loss: {eval_loss:.3f} | Val Acc: {eval_accuracy * 100:.2f}%') if tensorboard: writer.add_scalars(f'Accuracy', { - 'Train': epoch_acc}, epoch_i) + 'Train': epoch_acc}, epoch_i+1) writer.add_scalars(f'Loss', { - 'Train': epoch_loss}, epoch_i) + 'Train': epoch_loss}, epoch_i+1) if validation_rate > 0.0: writer.add_scalars(f'Accuracy', { - 'Validation': eval_accuracy}, epoch_i) + 'Validation': eval_accuracy}, epoch_i+1) writer.add_scalars(f'Loss', { - 'Validation': eval_loss}, epoch_i) + 'Validation': eval_loss}, epoch_i+1) self.epoch_i += 1 diff --git a/train.py b/train.py index d3cb4ef..9ac77d4 100644 --- a/train.py +++ b/train.py @@ -4,7 +4,7 @@ from SuperTagger.Utils.helpers import read_csv_pgbar, load_obj #### DATA #### file_path = 'Datasets/m2_dataset.csv' -df = read_csv_pgbar(file_path,1000) +df = read_csv_pgbar(file_path,100) texts = df['X'].tolist() tags = df['Y1'].tolist() -- GitLab