diff --git a/trytorch/train_model_baseline.py b/trytorch/train_model_baseline.py
index c772b818c5f0b950ad6c729a7d919224e407356a..7b9dfecc7b5c785a529975c2d6d9ed60baa45256 100644
--- a/trytorch/train_model_baseline.py
+++ b/trytorch/train_model_baseline.py
@@ -1,237 +1,230 @@
 import numpy as np
 import sys
 import os
+import argparse
 from transformers import BertTokenizer, BertModel
 from transformers.tokenization_utils_base import BatchEncoding
 import torch
 from torch import nn
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
-#from torchmetrics import F1Score
-#from torch.autograd import Variable
 from tqdm import tqdm
 
 bert = 'bert-base-multilingual-cased'
-#bert_embeddings = BertModel.from_pretrained(bert)
+tokenizer = BertTokenizer.from_pretrained(bert)
 
 class LSTM(nn.Module):
-    def __init__(self, batch_size, input_size, hidden_size, num_layers=1, bidirectional=False):
-        super().__init__()
-        self.batch_size = batch_size
-        self.hidden_size = hidden_size
-        self.bert_embeddings = BertModel.from_pretrained(bert)
-        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
-        d = 2 if bidirectional else 1
-        self.hiddenToLabel = nn.Linear(d * hidden_size, 1)
-        self.act = nn.Sigmoid()
-
-    def forward(self, batch):
-        output = self.bert_embeddings(**batch)
-        output768 = output.last_hidden_state
-        output64, (last_hidden_state, last_cell_state) = self.lstm(output768)
-        #output64, self.hidden = self.lstm(output768, self.hidden)
-        #print("output64=", output64.shape)
-        #print("last_hidden_state", last_hidden_state.shape)
-        #print("last_cell_state", last_cell_state.shape)
-        output1 = self.hiddenToLabel(output64)
-        #print("output1=", output1.shape)
-        return self.act(output1[:,:,0])
-
-        #lstm_out, self.hidden = self.lstm(output, self.hidden)
+    def __init__(self, batch_size, input_size, hidden_size, n_layers=1, bidirectional=False):
+        super().__init__()
+        self.batch_size = batch_size
+        self.hidden_size = hidden_size
+        self.bert_embeddings = BertModel.from_pretrained(bert)
+        self.lstm = nn.LSTM(input_size, hidden_size, n_layers, batch_first=True, bidirectional=bidirectional)
+        d = 2 if bidirectional else 1
+        self.hiddenToLabel = nn.Linear(d * hidden_size, 1)
+        self.act = nn.Sigmoid()
+
+    def forward(self, batch):
+        output = self.bert_embeddings(**batch)
+        output768 = output.last_hidden_state
+        output64, (last_hidden_state, last_cell_state) = self.lstm(output768)
+        #output64, self.hidden = self.lstm(output768, self.hidden)
+        #print("output64=", output64.shape)
+        #print("last_hidden_state", last_hidden_state.shape)
+        #print("last_cell_state", last_cell_state.shape)
+        output1 = self.hiddenToLabel(output64)
+        #print("output1=", output1.shape)
+        return self.act(output1[:,:,0])
+
+        #lstm_out, self.hidden = self.lstm(output, self.hidden)
 
 class SentenceBatch():
-    def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels, uposes = None, deprels = None, dheads = None):
-        self.sentence_ids = sentence_ids
-        self.tokens = tokens
-        self.tok_ids = pad_sequence(tok_ids, batch_first=True) #bert token ids
-        self.tok_types = pad_sequence(tok_types, batch_first=True)
-        self.tok_masks = pad_sequence(tok_masks, batch_first=True)
-        self.labels = pad_sequence(labels, batch_first=True)
-        if uposes is not None:
-            self.uposes = pad_sequence(uposes, batch_first=True)
-        else: self.uposes = None
-        if deprels is not None:
-            self.deprels = pad_sequence(deprels, batch_first=True)
-        else: self.deprels = None
-        if dheads is not None:
-            self.dheads = pad_sequence(dheads, batch_first=True)
-        else: self.dheads = None
-        self.labels = pad_sequence(labels, batch_first=True)
-
-    def getBatchEncoding(self):
-        dico = { 'input_ids': self.tok_ids, 'token_type_ids': self.tok_types, 'attention_mask': self.tok_masks }
-        return BatchEncoding(dico)
-
+    def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels, uposes = None, deprels = None, dheads = None):
+        self.sentence_ids = sentence_ids
+        self.tokens = tokens
+        self.tok_ids = pad_sequence(tok_ids, batch_first=True) #bert token ids
+        self.tok_types = pad_sequence(tok_types, batch_first=True)
+        self.tok_masks = pad_sequence(tok_masks, batch_first=True)
+        self.labels = pad_sequence(labels, batch_first=True)
+        if uposes is not None:
+            self.uposes = pad_sequence(uposes, batch_first=True)
+        else: self.uposes = None
+        if deprels is not None:
+            self.deprels = pad_sequence(deprels, batch_first=True)
+        else: self.deprels = None
+        if dheads is not None:
+            self.dheads = pad_sequence(dheads, batch_first=True)
+        else: self.dheads = None
+        self.labels = pad_sequence(labels, batch_first=True)
+
+    def getBatchEncoding(self):
+        dico = { 'input_ids': self.tok_ids, 'token_type_ids': self.tok_types, 'attention_mask': self.tok_masks }
+        return BatchEncoding(dico)
+
 def generate_sentence_list(corpus, sset, fmt):
-    #move that part to parse_corpus.py
-    parsed_data = os.path.join("parsed_data", f"parsed_{corpus}_{sset}.{fmt}.npz")
-    #print("PATH", parsed_data)
-    if not os.path.isfile(parsed_data):
-        print("you must parse the corpus before training it")
-        sys.exit()
-    data = np.load(parsed_data, allow_pickle = True)
-    tok_ids, toks, labels = [data[f] for f in data.files]
-
-    sentences = []
-
-    if fmt == 'conllu':
-        previous_i = 0
-        for index, tok_id in enumerate(tok_ids):
-            if index > 0 and (tok_id == 1 or tok_id == '1' or (isinstance(tok_id, str) and tok_id.startswith('1-'))):
-                if index - previous_i <= 510:
-                    sentences.append((toks[previous_i:index], labels[previous_i:index]))
-                    previous_i = index
-                else:
-                    sep = previous_i + 510
-                    sentences.append((toks[previous_i:sep], labels[previous_i:sep]))
-                    if sep - previous_i > 510:
-                        print("still too long sentence...")
-                        sys.exit()
-                    sentences.append((toks[sep:index], labels[sep:index]))
-    else:
-        max_length = 510
-        nb_toks = len(tok_ids)
-        for i in range(0, nb_toks - max_length, max_length):
-            #add slices of 510 tokens
-            sentences.append((toks[i:(i + max_length)], labels[i:(i + max_length)]))
-        sentences.append((toks[-(nb_toks % max_length):], labels[-(nb_toks % max_length):]))
-
-    indexed_sentences = [0] * len(sentences)
-    for i, sentence in enumerate(sentences):
-        indexed_sentences[i] = (i, sentence)
-
-    return indexed_sentences
+    #move that part to parse_corpus.py
+    parsed_data = os.path.join("parsed_data", f"parsed_{corpus}_{sset}.{fmt}.npz")
+    #print("PATH", parsed_data)
+    if not os.path.isfile(parsed_data):
+        print("you must parse the corpus before training it")
+        sys.exit()
+    data = np.load(parsed_data, allow_pickle = True)
+    tok_ids, toks, labels = [data[f] for f in data.files]
+
+    sentences = []
+
+    if fmt == 'conllu':
+        previous_i = 0
+        for index, tok_id in enumerate(tok_ids):
+            if index > 0 and (tok_id == 1 or tok_id == '1' or (isinstance(tok_id, str) and tok_id.startswith('1-'))):
+                if index - previous_i <= 510:
+                    sentences.append((toks[previous_i:index], labels[previous_i:index]))
+                    previous_i = index
+                else:
+                    sep = previous_i + 510
+                    sentences.append((toks[previous_i:sep], labels[previous_i:sep]))
+                    if sep - previous_i > 510:
+                        print("still too long sentence...")
+                        sys.exit()
+                    sentences.append((toks[sep:index], labels[sep:index]))
+    else:
+        max_length = 510
+        nb_toks = len(tok_ids)
+        for i in range(0, nb_toks - max_length, max_length):
+            #add slices of 510 tokens
+            sentences.append((toks[i:(i + max_length)], labels[i:(i + max_length)]))
+        sentences.append((toks[-(nb_toks % max_length):], labels[-(nb_toks % max_length):]))
+
+    indexed_sentences = [0] * len(sentences)
+    for i, sentence in enumerate(sentences):
+        indexed_sentences[i] = (i, sentence)
+
+    return indexed_sentences
 
 def add_cls_sep(sentence):
-    return ['[CLS]'] + list(sentence) + ['[SEP]']
-
+    return ['[CLS]'] + list(sentence) + ['[SEP]']
+
 def toks_to_ids(sentence):
-    #print("sentence=", sentence)
-    tokens = ['[CLS]'] + list(sentence) + ['[SEP]']
-    #print("tokens=", tokens)
-    tokenizer = BertTokenizer.from_pretrained(bert)
-    return torch.tensor(tokenizer.convert_tokens_to_ids(tokens)) #len(tokens)
-    #print("token_ids=", token_ids)
+    #print("sentence=", sentence)
+    tokens = ['[CLS]'] + list(sentence) + ['[SEP]']
+    #print("tokens=", tokens)
+    return torch.tensor(tokenizer.convert_tokens_to_ids(tokens)) #len(tokens)
+    #print("token_ids=", token_ids)
 
 def make_labels(sentence):
-    zero = np.array([0])
-    add_two = np.concatenate((np.concatenate((zero, sentence)), zero)) #add label 0 for [CLS] and [SEP]
-    return torch.from_numpy(add_two).float()
+    zero = np.array([0])
+    add_two = np.concatenate((np.concatenate((zero, sentence)), zero)) #add label 0 for [CLS] and [SEP]
+    return torch.from_numpy(add_two).float()
 
 def make_tok_types(l):
-    return torch.zeros(l, dtype=torch.int32)
+    return torch.zeros(l, dtype=torch.int32)
 
 def make_tok_masks(l):
-    return torch.ones(l, dtype=torch.int32)
+    return torch.ones(l, dtype=torch.int32)
 
 def collate_batch(batch):
-    sentence_ids, token_batch, label_batch = [i for i, (_, _) in batch], [j for _, (j, _) in batch], [k for _, (_, k) in batch]
-    #mappings = [make_mapping(sentence) for sentence in token_batch]
-    labels = [make_labels(sentence) for sentence in label_batch]
-    tokens = [add_cls_sep(sentence) for sentence in token_batch]
-    tok_ids = [toks_to_ids(sentence) for sentence in token_batch]
-    lengths = [len(toks) for toks in tok_ids]
-    tok_types = [make_tok_types(l) for l in lengths]
-    tok_masks = [make_tok_masks(l) for l in lengths]
-
-    return SentenceBatch(sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels)
-
+    sentence_ids, token_batch, label_batch = [0] * len(batch), [0] * len(batch), [0] * len(batch)
+
+    for i, (ids, (toks, labs)) in enumerate(batch):
+        sentence_ids[i] = ids
+        token_batch[i] = toks
+        label_batch[i] = labs
+
+    labels = [make_labels(sentence) for sentence in label_batch]
+    tokens = [add_cls_sep(sentence) for sentence in token_batch]
+    tok_ids = [torch.tensor(tokenizer.convert_tokens_to_ids(sentence)) for sentence in tokens]
+    lengths = [len(toks) for toks in tok_ids]
+    tok_types = [make_tok_types(l) for l in lengths]
+    tok_masks = [make_tok_masks(l) for l in lengths]
+
+    return SentenceBatch(sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels)
+
 def train(corpus, fmt):
-    print(f'starting training of {corpus} in format {fmt}...')
-    data = generate_sentence_list(corpus, 'train', fmt)
-
-    torch.manual_seed(1)
-
-    #PARAMETERS
-    batch_size = 32 if fmt == 'conllu' else 4
-    num_epochs = 10
-    lr = 0.0001
-    reg = 0.0005
-    dropout = 0.01
-    bidirectional = True
-    params = { 'fm' : fmt, 'bs': batch_size, 'ne': num_epochs, 'lr': lr, 'rg': reg, 'do': dropout, 'bi': bidirectional }
-
-    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
-    model = LSTM(batch_size, 768, 64, num_layers=1, bidirectional=bidirectional)
-    loss_fn = nn.BCELoss() #BCELoss
-    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)
-    model.train()
-
-    l = len(dataloader.dataset)
-
-    for epoch in range(num_epochs):
-        total_acc = 0
-        total_loss = 0
-        tp = 0
-        fp = 0
-        fn = 0
-
-        for sentence_batch in tqdm(dataloader):
-            optimizer.zero_grad()
-            label_batch = sentence_batch.labels
-            #print("label_batch", label_batch.shape)
-            #print("tok_ids", sentence_batch.tok_ids.shape)
-            pred = model(sentence_batch.getBatchEncoding())
-            #print("pred", pred.shape)
-            loss = loss_fn(pred, label_batch)
-            loss.backward()
-            optimizer.step()
-            pred_binary = (pred >= 0.5).float()
-
-            for i in range(label_batch.size(0)):
-                for j in range(label_batch.size(1)):
-                    if pred_binary[i,j] == 1.:
-                        if label_batch[i,j] == 1.:
-                            tp += 1
-                        else: fp += 1
-                    elif label_batch[i,j] == 1.: fn += 1
-            #print("tp,fp,fn=",tp/label_batch.size(1),fp/label_batch.size(1),fn/label_batch.size(1))
-            #nb_1 = pred_binary.sum()
-            #print("nb predicted 1=", nb_1)
-            sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
-            assert (sum_score <= batch_size)
-            total_acc += sum_score
-            total_loss += loss.item() #*label_batch.size(0)
-
-        f1 = tp / (tp + (fp + fn) / 2)
-        print(f"Epoch {epoch} Accuracy {total_acc/l} Loss {total_loss/l} F1 {f1}")
-
-    print('done training')
-    output_file = save_model(model, 'baseline', corpus, params)
-    print(f'model saved at {output_file}')
+    print(f'starting training of {corpus} in format {fmt}...')
+    data = generate_sentence_list(corpus, 'train', fmt)
+
+    torch.manual_seed(1)
+
+    #PARAMETERS
+    batch_size = 32 if fmt == 'conllu' else 4
+    n_epochs = 10
+    lr = 0.0001
+    reg = 0.001
+    dropout = 0.1
+    n_layers = 1
+    n_hidden = 64
+    bidirectional = True
+    params = { 'fm' : fmt, 'nl': n_layers, 'nh': n_hidden, 'bs': batch_size, 'ne': n_epochs, 'lr': lr, 'rg': reg, 'do': dropout, 'bi': bidirectional }
+
+    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
+    model = LSTM(batch_size, 768, n_hidden, n_layers, bidirectional=bidirectional)
+    loss_fn = nn.BCELoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)
+    model.train()
+
+    l = len(dataloader.dataset)
+
+    for epoch in range(n_epochs):
+        total_acc = 0
+        total_loss = 0
+        tp = 0
+        fp = 0
+        fn = 0
+
+        for sentence_batch in tqdm(dataloader):
+            optimizer.zero_grad()
+            label_batch = sentence_batch.labels
+            pred = model(sentence_batch.getBatchEncoding())
+            loss = loss_fn(pred, label_batch)
+            loss.backward()
+            optimizer.step()
+            pred_binary = (pred >= 0.5).float()
+
+            for i in range(label_batch.size(0)):
+                for j in range(label_batch.size(1)):
+                    if pred_binary[i,j] == 1.:
+                        if label_batch[i,j] == 1.:
+                            tp += 1
+                        else: fp += 1
+                    elif label_batch[i,j] == 1.: fn += 1
+            sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
+            assert (sum_score <= batch_size)
+            total_acc += sum_score
+            total_loss += loss.item() #*label_batch.size(0)
+
+        f1 = tp / (tp + (fp + fn) / 2)
+        print(f"Epoch {epoch} Accuracy {total_acc/l} Loss {total_loss/l} F1 {f1}")
+
+    print('done training')
+    output_file = save_model(model, 'baseline', corpus, params)
+    print(f'model saved at {output_file}')
 
 def save_model(model, model_type, corpus, params):
-    models_dir = 'saved_models'
-    if not os.path.isdir(models_dir):
-        os.system(f'mkdir {models_dir}')
-    corpus_dir = os.path.join(models_dir, corpus)
-    if not os.path.isdir(corpus_dir):
-        os.system(f'mkdir {corpus_dir}')
-    model_file = f"{model_type}_{corpus}_{params['fm']}_{params['bs']}_{params['ne']}_{params['lr']}_{params['rg']}_{params['do']}_{params['bi']}.pth"
-    output_file = os.path.join(corpus_dir, model_file)
-    if not os.path.isfile(output_file):
-        torch.save(model, output_file)
-    else:
-        print("Model already registers. Do you wish to overwrite? (Y/n)")
-        user = input()
-        if user == "" or user == "Y" or user == "y":
-            torch.save(model, output_file)
-    return output_file
+    models_dir = 'saved_models'
+    if not os.path.isdir(models_dir):
+        os.system(f'mkdir {models_dir}')
+    corpus_dir = os.path.join(models_dir, corpus)
+    if not os.path.isdir(corpus_dir):
+        os.system(f'mkdir {corpus_dir}')
+    model_file = f"{model_type}_{corpus}_{params['nl']}_{params['nh']}_{params['fm']}_{params['bs']}_{params['ne']}_{params['lr']}_{params['rg']}_{params['do']}_{params['bi']}.pth"
+    output_file = os.path.join(corpus_dir, model_file)
+    if not os.path.isfile(output_file):
+        torch.save(model, output_file)
+    else:
+        print("Model already exists. Do you wish to overwrite? (Y/n)")
+        user = input()
+        if user == "" or user == "Y" or user == "y":
+            torch.save(model, output_file)
+    return output_file
 
 def main():
-    if len(sys.argv) < 2 or len(sys.argv) > 3:
-        print("usage: train_model_baseline.py <corpus> [<conllu/tok>]")
-        sys.exit()
-    corpus = sys.argv[1]
-    fmt = 'conllu'
-    if len(sys.argv) == 3:
-        fmt = sys.argv[2]
-        if fmt != 'conllu' and fmt != 'tok':
-            print("usage: train_model_baseline.py <corpus> [<conllu/tok>]")
-            sys.exit()
-    train(corpus, fmt)
+    parser = argparse.ArgumentParser(description='Train baseline model')
+    parser.add_argument('--corpus', required=True, help='corpus to train')
+    parser.add_argument('--format', default='conllu', choices=['conllu', 'tok'], help='tok or conllu')
+    params = parser.parse_args()
+    train(params.corpus, params.format)
 
 if __name__ == '__main__':
-    main()
+    main()
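
Usage sketch for the new argparse entry point: the script now takes named flags instead of positional arguments. The corpus name below is a placeholder, and the commands assume parsed_data/parsed_<corpus>_train.<format>.npz has already been produced (e.g. by parse_corpus.py) relative to the working directory:

    python trytorch/train_model_baseline.py --corpus my_corpus --format conllu
    python trytorch/train_model_baseline.py --corpus my_corpus --format tok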