Skip to content
Snippets Groups Projects
Commit 23e9ecb2 authored by Alice Pain's avatar Alice Pain
Browse files

minor

parent b0a2c22a
Branches
No related tags found
No related merge requests found
...@@ -43,8 +43,9 @@ class LSTM(nn.Module): ...@@ -43,8 +43,9 @@ class LSTM(nn.Module):
class SentenceBatch(): class SentenceBatch():
def __init__(self, sentence_ids, tok_ids, tok_types, tok_masks, labels): def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels):
self.sentence_ids = sentence_ids self.sentence_ids = sentence_ids
self.tokens = tokens
self.tok_ids = pad_sequence(tok_ids, batch_first=True) self.tok_ids = pad_sequence(tok_ids, batch_first=True)
self.tok_types = pad_sequence(tok_types, batch_first=True) self.tok_types = pad_sequence(tok_types, batch_first=True)
self.tok_masks = pad_sequence(tok_masks, batch_first=True) self.tok_masks = pad_sequence(tok_masks, batch_first=True)
...@@ -120,7 +121,7 @@ def collate_batch(batch): ...@@ -120,7 +121,7 @@ def collate_batch(batch):
tok_types = [make_tok_types(l) for l in lengths] tok_types = [make_tok_types(l) for l in lengths]
tok_masks = [make_tok_masks(l) for l in lengths] tok_masks = [make_tok_masks(l) for l in lengths]
return SentenceBatch(sentence_ids, tok_ids, tok_types, tok_masks, labels) return SentenceBatch(sentence_ids, token_batch, tok_ids, tok_types, tok_masks, labels)
def train(corpus, fmt): def train(corpus, fmt):
print(f'starting training of {corpus} in format {fmt}...') print(f'starting training of {corpus} in format {fmt}...')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment