diff --git a/SuperTagger/Utils/SymbolTokenizer.py b/SuperTagger/Utils/SymbolTokenizer.py
index eaf5ae4ba7526eaaa3c1f52d0caa8fc7439646c5..e5095d1acc392ca5cab3cf31142c3b8143964e9e 100644
--- a/SuperTagger/Utils/SymbolTokenizer.py
+++ b/SuperTagger/Utils/SymbolTokenizer.py
@@ -11,33 +11,32 @@ def load_obj(name):
 
 class SymbolTokenizer():
 
     def __init__(self, index_to_super):
-        """@params tokenizer (PretrainedTokenizer): Tokenizer that tokenizes text """
+        """@params index_to_super: dict mapping IDs to tags """
         self.index_to_super = index_to_super
         self.super_to_index = {v: int(k) for k, v in self.index_to_super.items()}
 
     def lenSuper(self):
+        """@return the size of the ID-to-tag dict, plus one """
         return len(self.index_to_super) + 1
 
     def convert_batchs_to_ids(self, tags, sents_tokenized):
         encoded_labels = []
         labels = [[self.super_to_index[str(symbol)] for symbol in sents] for sents in tags]
         for l, s in zip(labels, sents_tokenized):
-            super_tok = pad_sequence(l,len(s))
+            super_tok = pad_sequence(l, len(s))
             encoded_labels.append(super_tok)
         return torch.tensor(encoded_labels)
 
     def convert_ids_to_tags(self, tags_ids):
-        labels = [[self.index_to_super[int(symbol)] for symbol in sents if self.index_to_super[int(symbol)] != '<unk>'] for sents in tags_ids]
+        labels = [[self.index_to_super[int(symbol)] for symbol in sents if self.index_to_super[int(symbol)] != '<unk>']
+                  for sents in tags_ids]
         return labels
 
+
 def pad_sequence(sequences, max_len=400):
     padded = [0] * max_len
     padded[:len(sequences)] = sequences
     return padded
-
-
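
For context, a short usage sketch of the patched API (not part of the diff). The index_to_super mapping and the example tags below are hypothetical; in the project the mapping comes from a pickled object loaded via load_obj(). Assuming ID 0 maps to '<unk>', the zeros introduced by pad_sequence are filtered back out by convert_ids_to_tags:

# Usage sketch, assuming the patched module is importable from the repo root.
from SuperTagger.Utils.SymbolTokenizer import SymbolTokenizer

index_to_super = {0: '<unk>', 1: 'np', 2: 's'}   # hypothetical ID-to-tag dict
tokenizer = SymbolTokenizer(index_to_super)

print(tokenizer.lenSuper())  # 4 == len(index_to_super) + 1

# convert_batchs_to_ids pads each sentence's tag IDs to the length of its
# tokenized sentence, then stacks the result into a tensor.
tags = [['np', 's']]
sents_tokenized = [['the', 'cat', 'sleeps']]
ids = tokenizer.convert_batchs_to_ids(tags, sents_tokenized)
print(ids)  # tensor([[1, 2, 0]]) -- trailing 0 is padding

# convert_ids_to_tags drops '<unk>' entries on the way back, so the
# padding zeros disappear under the mapping assumed above.
print(tokenizer.convert_ids_to_tags([[1, 2, 0]]))  # [['np', 's']]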