Commit eeb4774c authored by Julien Rabault

Comment

parent ea12a5e7
@@ -11,33 +11,32 @@ def load_obj(name):
import torch  # needed for torch.tensor below; presumably imported at the top of the full file


class SymbolTokenizer():
    def __init__(self, index_to_super):
        """@param index_to_super: dict mapping label IDs to supertags"""
        self.index_to_super = index_to_super
        # Inverse mapping, from supertag to integer ID.
        self.super_to_index = {v: int(k) for k, v in self.index_to_super.items()}

    def lenSuper(self):
        """@return: size of the ID-to-tag vocabulary, plus one (padding slot)"""
        return len(self.index_to_super) + 1

    def convert_batchs_to_ids(self, tags, sents_tokenized):
        # Convert each supertag to its ID, then pad every label sequence
        # to the length of the corresponding tokenized sentence.
        encoded_labels = []
        labels = [[self.super_to_index[str(symbol)] for symbol in sents] for sents in tags]
        for l, s in zip(labels, sents_tokenized):
            super_tok = pad_sequence(l, len(s))
            encoded_labels.append(super_tok)
        return torch.tensor(encoded_labels)

    def convert_ids_to_tags(self, tags_ids):
        # Convert IDs back to supertags, dropping '<unk>' entries.
        labels = [[self.index_to_super[int(symbol)] for symbol in sents
                   if self.index_to_super[int(symbol)] != '<unk>']
                  for sents in tags_ids]
        return labels


def pad_sequence(sequences, max_len=400):
    # Right-pad with zeros (and truncate) so the result has exactly max_len entries.
    padded = [0] * max_len
    padded[:len(sequences)] = sequences[:max_len]
    return padded
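
For context, a minimal usage sketch of the class above; the index_to_super mapping and the tags/sentences are invented for illustration (in the real project this mapping is presumably built elsewhere, e.g. via load_obj):

# Hypothetical usage sketch -- the mapping and data below are made up.
index_to_super = {0: '<unk>', 1: 'NP', 2: 'VP'}
tokenizer = SymbolTokenizer(index_to_super)

tags = [['NP', 'VP']]                          # one sentence of supertags
sents_tokenized = [['the', 'cat', 'sleeps']]   # its tokenized counterpart

ids = tokenizer.convert_batchs_to_ids(tags, sents_tokenized)
print(ids)                                     # tensor([[1, 2, 0]]) -- padded to 3 tokens
print(tokenizer.convert_ids_to_tags(ids.tolist()))  # [['NP', 'VP']] -- '<unk>' padding dropped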