Skip to content
Snippets Groups Projects
Commit 50aa065d authored by Julien Rabault's avatar Julien Rabault
Browse files

fix epoch

parent a7008a57
No related branches found
No related tags found
No related merge requests found
...@@ -236,21 +236,21 @@ class SuperTagger: ...@@ -236,21 +236,21 @@ class SuperTagger:
eval_accuracy, eval_loss, nb_eval_steps = self.__eval_epoch(validation_dataloader) eval_accuracy, eval_loss, nb_eval_steps = self.__eval_epoch(validation_dataloader)
print("") print("")
print(f'Epoch: {epoch_i:02} | Epoch Time: {training_time}') print(f'Epoch: {epoch_i+1:02} | Epoch Time: {training_time}')
print(f'\tTrain Loss: {epoch_loss:.3f} | Train Acc: {epoch_acc * 100:.2f}%') print(f'\tTrain Loss: {epoch_loss:.3f} | Train Acc: {epoch_acc * 100:.2f}%')
if validation_rate > 0.0: if validation_rate > 0.0:
print(f'\tVal Loss: {eval_loss:.3f} | Val Acc: {eval_accuracy * 100:.2f}%') print(f'\tVal Loss: {eval_loss:.3f} | Val Acc: {eval_accuracy * 100:.2f}%')
if tensorboard: if tensorboard:
writer.add_scalars(f'Accuracy', { writer.add_scalars(f'Accuracy', {
'Train': epoch_acc}, epoch_i) 'Train': epoch_acc}, epoch_i+1)
writer.add_scalars(f'Loss', { writer.add_scalars(f'Loss', {
'Train': epoch_loss}, epoch_i) 'Train': epoch_loss}, epoch_i+1)
if validation_rate > 0.0: if validation_rate > 0.0:
writer.add_scalars(f'Accuracy', { writer.add_scalars(f'Accuracy', {
'Validation': eval_accuracy}, epoch_i) 'Validation': eval_accuracy}, epoch_i+1)
writer.add_scalars(f'Loss', { writer.add_scalars(f'Loss', {
'Validation': eval_loss}, epoch_i) 'Validation': eval_loss}, epoch_i+1)
self.epoch_i += 1 self.epoch_i += 1
......
...@@ -4,7 +4,7 @@ from SuperTagger.Utils.helpers import read_csv_pgbar, load_obj ...@@ -4,7 +4,7 @@ from SuperTagger.Utils.helpers import read_csv_pgbar, load_obj
#### DATA #### #### DATA ####
file_path = 'Datasets/m2_dataset.csv' file_path = 'Datasets/m2_dataset.csv'
df = read_csv_pgbar(file_path,1000) df = read_csv_pgbar(file_path,100)
texts = df['X'].tolist() texts = df['X'].tolist()
tags = df['Y1'].tolist() tags = df['Y1'].tolist()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment