From 34f46ef630915c047c827e8b0603bbe5f14e36d3 Mon Sep 17 00:00:00 2001
From: PNRIA - Julien <julien.rabault@irit.fr>
Date: Wed, 29 Jun 2022 15:41:32 +0200
Subject: [PATCH] Correct assert

---
 SuperTagger/SuperTagger.py   | 4 ++--
 SuperTagger/Utils/helpers.py | 3 +--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py
index 2f5dceb..1108f21 100644
--- a/SuperTagger/SuperTagger.py
+++ b/SuperTagger/SuperTagger.py
@@ -171,8 +171,8 @@ class SuperTagger:
         @return: tags prediction for all sentences (no argmax tags, convert tags, embedding layer of bert )
         """
         assert self.trainable or self.model is None, "Please use the create_new_model(...) or load_weights(...) " \
-                                                     "function before the predict, the model is not integrated "
-        assert type(sentences) == str or type(sentences) == list[str], "param sentences: list of sentences : list[" \
+                                                        "function before the predict, the model is not integrated "
+        assert type(sentences) == str or type(sentences) == list, "param sentences: list of sentences : list[" \
                                                                        "str] OR one sentences : str "
         sentences = [sentences] if type(sentences) == str else sentences
 
diff --git a/SuperTagger/Utils/helpers.py b/SuperTagger/Utils/helpers.py
index 7c4a2a0..1262485 100644
--- a/SuperTagger/Utils/helpers.py
+++ b/SuperTagger/Utils/helpers.py
@@ -35,9 +35,8 @@ def categorical_accuracy_str(preds: list[list[float]], truth: list[list[float]])
     for i in range(len(truth)):
         sublist_truth = truth[i]
         sublist_preds = preds[i]
+        nb_label += len(sublist_truth)
         for j in range(min(len(sublist_truth), len(sublist_preds))):
             if str(sublist_truth[j]) == str(sublist_preds[j]):
                 good_label += 1
-            nb_label += 1
-
     return good_label / nb_label
-- 
GitLab