import torch


class SymbolTokenizer():
    """Convert supertag symbols to integer ids and back for model training.

    NOTE(review): reconstructed from a collapsed git patch. The original
    file also defines a ``load_obj(name)`` pickle loader whose body is not
    visible in the patch hunk, so it is not reproduced here — confirm
    against the repository before replacing the file wholesale.
    """

    def __init__(self, index_to_super):
        """@params index_to_super: dict mapping id -> tag; keys must be convertible to int"""
        self.index_to_super = index_to_super
        # Inverse mapping tag -> int id. Keys may arrive as strings when the
        # dict was loaded from disk, hence the int(k) normalisation.
        self.super_to_index = {v: int(k) for k, v in self.index_to_super.items()}

    def lenSuper(self):
        """@return size of the tag vocabulary + 1 (one extra slot, presumably
        reserved for the 0 padding id — TODO confirm against the model)."""
        return len(self.index_to_super) + 1

    def convert_batchs_to_ids(self, tags, sents_tokenized):
        """Encode a batch of tag sequences as a padded tensor of ids.

        Each tag sequence is converted via super_to_index and padded with 0
        up to the token length of the matching tokenized sentence.

        @params tags: batch (list) of tag sequences
        @params sents_tokenized: batch of tokenized sentences; only their
            lengths are used, to decide the padded width of each row
        @return torch tensor of ids

        NOTE(review): torch.tensor() requires every padded row to have the
        same length, i.e. all sentences in the batch must tokenize to one
        common length — verify this invariant holds at the call site.
        """
        encoded_labels = []
        labels = [[self.super_to_index[str(symbol)] for symbol in sents] for sents in tags]
        for lab, sent in zip(labels, sents_tokenized):
            encoded_labels.append(pad_sequence(lab, len(sent)))
        return torch.tensor(encoded_labels)

    def convert_ids_to_tags(self, tags_ids):
        """Decode a batch of id sequences back to tags, dropping '<unk>'.

        @params tags_ids: batch of sequences of integer (or int-convertible) ids
        @return batch of tag lists, with '<unk>' entries filtered out
        """
        return [[self.index_to_super[int(symbol)] for symbol in sents
                 if self.index_to_super[int(symbol)] != '<unk>']
                for sents in tags_ids]


def pad_sequence(sequences, max_len=400):
    """Pad a sequence of ids with 0 (or truncate it) to exactly max_len.

    Bug fix: the original ``padded[:len(sequences)] = sequences`` slice
    assignment *grew* the list beyond max_len whenever the input was longer
    than max_len, silently breaking the fixed-width contract. The result is
    now always exactly max_len items long.

    @params sequences: list of integer ids
    @params max_len: target width of the returned list
    @return list of length max_len
    """
    truncated = list(sequences[:max_len])
    return truncated + [0] * (max_len - len(truncated))