diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py index 2f5dcebbc95a1f7d9c009d37a2c2b74630ede9e2..1108f2192fb47b10910ab840dfba79eaae82ad7c 100644 --- a/SuperTagger/SuperTagger.py +++ b/SuperTagger/SuperTagger.py @@ -171,8 +171,8 @@ class SuperTagger: @return: tags prediction for all sentences (no argmax tags, convert tags, embedding layer of bert ) """ assert self.trainable or self.model is None, "Please use the create_new_model(...) or load_weights(...) " \ - "function before the predict, the model is not integrated " - assert type(sentences) == str or type(sentences) == list[str], "param sentences: list of sentences : list[" \ + "function before the predict, the model is not integrated " + assert type(sentences) == str or type(sentences) == list, "param sentences: list of sentences : list[" \ "str] OR one sentences : str " sentences = [sentences] if type(sentences) == str else sentences diff --git a/SuperTagger/Utils/helpers.py b/SuperTagger/Utils/helpers.py index 7c4a2a010627656972c531f8b57118693ac85150..1262485308aa2983f1e7be8af2378dee80dffe54 100644 --- a/SuperTagger/Utils/helpers.py +++ b/SuperTagger/Utils/helpers.py @@ -35,9 +35,8 @@ def categorical_accuracy_str(preds: list[list[float]], truth: list[list[float]]) for i in range(len(truth)): sublist_truth = truth[i] sublist_preds = preds[i] + nb_label += len(sublist_truth) for j in range(min(len(sublist_truth), len(sublist_preds))): if str(sublist_truth[j]) == str(sublist_preds[j]): good_label += 1 - nb_label += 1 - return good_label / nb_label