From 34f46ef630915c047c827e8b0603bbe5f14e36d3 Mon Sep 17 00:00:00 2001 From: PNRIA - Julien <julien.rabault@irit.fr> Date: Wed, 29 Jun 2022 15:41:32 +0200 Subject: [PATCH] Correct assert --- SuperTagger/SuperTagger.py | 4 ++-- SuperTagger/Utils/helpers.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py index 2f5dceb..1108f21 100644 --- a/SuperTagger/SuperTagger.py +++ b/SuperTagger/SuperTagger.py @@ -171,8 +171,8 @@ class SuperTagger: @return: tags prediction for all sentences (no argmax tags, convert tags, embedding layer of bert ) """ assert self.trainable or self.model is None, "Please use the create_new_model(...) or load_weights(...) " \ - "function before the predict, the model is not integrated " - assert type(sentences) == str or type(sentences) == list[str], "param sentences: list of sentences : list[" \ + "function before the predict, the model is not integrated " + assert type(sentences) == str or type(sentences) == list, "param sentences: list of sentences : list[" \ "str] OR one sentences : str " sentences = [sentences] if type(sentences) == str else sentences diff --git a/SuperTagger/Utils/helpers.py b/SuperTagger/Utils/helpers.py index 7c4a2a0..1262485 100644 --- a/SuperTagger/Utils/helpers.py +++ b/SuperTagger/Utils/helpers.py @@ -35,9 +35,8 @@ def categorical_accuracy_str(preds: list[list[float]], truth: list[list[float]]) for i in range(len(truth)): sublist_truth = truth[i] sublist_preds = preds[i] + nb_label += len(sublist_truth) for j in range(min(len(sublist_truth), len(sublist_preds))): if str(sublist_truth[j]) == str(sublist_preds[j]): good_label += 1 - nb_label += 1 - return good_label / nb_label -- GitLab