Commit 78235faf authored by Alice Pain

minor

parent f6e16e8f
import numpy as np
import sys
import os
import argparse
from transformers import BertTokenizer, BertModel
from transformers.tokenization_utils_base import BatchEncoding
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
#from torchmetrics import F1Score
#from torch.autograd import Variable
from tqdm import tqdm
bert = 'bert-base-multilingual-cased'
#bert_embeddings = BertModel.from_pretrained(bert)
tokenizer = BertTokenizer.from_pretrained(bert)
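# ---------------------------------------------------------------------------
# Baseline model: multilingual BERT embeddings fed into an (optionally
# bidirectional) LSTM, then a linear layer and a sigmoid, giving one
# probability per token (binary label per token).
# Pipeline: generate_sentence_list -> DataLoader(collate_fn=collate_batch)
#           -> LSTM.forward -> BCELoss.
# ---------------------------------------------------------------------------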
class LSTM(nn.Module):
    """Multilingual BERT embeddings -> (bi)LSTM -> linear + sigmoid, one probability per token."""

    def __init__(self, batch_size, input_size, hidden_size, n_layers=1, bidirectional=False):
        super().__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.bert_embeddings = BertModel.from_pretrained(bert)
        self.lstm = nn.LSTM(input_size, hidden_size, n_layers, batch_first=True, bidirectional=bidirectional)
        d = 2 if bidirectional else 1
        self.hiddenToLabel = nn.Linear(d * hidden_size, 1)
        self.act = nn.Sigmoid()

    def forward(self, batch):
        output = self.bert_embeddings(**batch)
        output768 = output.last_hidden_state                                   # (batch, seq_len, 768)
        output64, (last_hidden_state, last_cell_state) = self.lstm(output768)  # (batch, seq_len, d * hidden_size)
        output1 = self.hiddenToLabel(output64)                                 # (batch, seq_len, 1)
        return self.act(output1[:, :, 0])
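# Shapes through LSTM.forward (d = 2 if bidirectional else 1):
#   BERT last_hidden_state  : (batch, seq_len, 768)
#   LSTM output             : (batch, seq_len, d * hidden_size)
#   hiddenToLabel + sigmoid : (batch, seq_len) probabilities in [0, 1]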
class SentenceBatch():
    """A batch of sentences with padded tensors and optional morpho-syntactic features."""

    def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels, uposes=None, deprels=None, dheads=None):
        self.sentence_ids = sentence_ids
        self.tokens = tokens
        self.tok_ids = pad_sequence(tok_ids, batch_first=True)  # BERT token ids
        self.tok_types = pad_sequence(tok_types, batch_first=True)
        self.tok_masks = pad_sequence(tok_masks, batch_first=True)
        self.labels = pad_sequence(labels, batch_first=True)
        self.uposes = pad_sequence(uposes, batch_first=True) if uposes is not None else None
        self.deprels = pad_sequence(deprels, batch_first=True) if deprels is not None else None
        self.dheads = pad_sequence(dheads, batch_first=True) if dheads is not None else None

    def getBatchEncoding(self):
        dico = {'input_ids': self.tok_ids, 'token_type_ids': self.tok_types, 'attention_mask': self.tok_masks}
        return BatchEncoding(dico)
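# getBatchEncoding() wraps the padded id/type/mask tensors in a transformers
# BatchEncoding, so a SentenceBatch can be fed to the model with
# model(sentence_batch.getBatchEncoding()), which LSTM.forward unpacks as **batch.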
def generate_sentence_list(corpus, sset, fmt):
    #TODO: move that part to parse_corpus.py
    parsed_data = os.path.join("parsed_data", f"parsed_{corpus}_{sset}.{fmt}.npz")
    if not os.path.isfile(parsed_data):
        print("you must parse the corpus before training on it")
        sys.exit()
    data = np.load(parsed_data, allow_pickle=True)
    tok_ids, toks, labels = [data[f] for f in data.files]
    sentences = []
    if fmt == 'conllu':
        # cut the token stream at sentence boundaries (CoNLL-U token id 1 or 1-n),
        # keeping every chunk at 510 tokens or fewer (512 minus [CLS] and [SEP])
        previous_i = 0
        for index, tok_id in enumerate(tok_ids):
            if index > 0 and (tok_id == 1 or tok_id == '1' or (isinstance(tok_id, str) and tok_id.startswith('1-'))):
                if index - previous_i <= 510:
                    sentences.append((toks[previous_i:index], labels[previous_i:index]))
                else:
                    sep = previous_i + 510
                    sentences.append((toks[previous_i:sep], labels[previous_i:sep]))
                    if index - sep > 510:
                        print("still too long sentence...")
                        sys.exit()
                    sentences.append((toks[sep:index], labels[sep:index]))
                previous_i = index
    else:
        # .tok format has no sentence boundaries: slice the stream into chunks of 510 tokens
        max_length = 510
        nb_toks = len(tok_ids)
        for i in range(0, nb_toks - max_length, max_length):
            sentences.append((toks[i:(i + max_length)], labels[i:(i + max_length)]))
        sentences.append((toks[-(nb_toks % max_length):], labels[-(nb_toks % max_length):]))
    return list(enumerate(sentences))
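# generate_sentence_list returns [(i, (tokens, labels)), ...]; chunks are kept at
# 510 tokens or fewer so that, once [CLS] and [SEP] are added, no sequence exceeds
# BERT's 512-token limit.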
def add_cls_sep(sentence):
    return ['[CLS]'] + list(sentence) + ['[SEP]']
def toks_to_ids(sentence):
    tokens = ['[CLS]'] + list(sentence) + ['[SEP]']
    return torch.tensor(tokenizer.convert_tokens_to_ids(tokens))
def make_labels(sentence):
    zero = np.array([0])
    add_two = np.concatenate((zero, sentence, zero))  # add label 0 for [CLS] and [SEP]
    return torch.from_numpy(add_two).float()
def make_tok_types(l):
    return torch.zeros(l, dtype=torch.int32)
def make_tok_masks(l):
    return torch.ones(l, dtype=torch.int32)
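# token_type_ids are all zeros (single-segment input) and attention_mask is all
# ones for the real tokens; pad_sequence in SentenceBatch then pads both with
# zeros, so padded positions end up masked out of BERT's attention.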
def collate_batch(batch):
    # each item produced by generate_sentence_list is (sentence_id, (tokens, labels))
    sentence_ids = [ids for ids, _ in batch]
    token_batch = [toks for _, (toks, _) in batch]
    label_batch = [labs for _, (_, labs) in batch]
    labels = [make_labels(sentence) for sentence in label_batch]
    tokens = [add_cls_sep(sentence) for sentence in token_batch]
    tok_ids = [torch.tensor(tokenizer.convert_tokens_to_ids(sentence)) for sentence in tokens]
    lengths = [len(toks) for toks in tok_ids]
    tok_types = [make_tok_types(l) for l in lengths]
    tok_masks = [make_tok_masks(l) for l in lengths]
    return SentenceBatch(sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels)
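# Example wiring (a sketch mirroring train() below; corpus and fmt are placeholders):
#   data = generate_sentence_list(corpus, 'train', fmt)
#   loader = DataLoader(data, batch_size=32, shuffle=True, collate_fn=collate_batch)
#   batch = next(iter(loader))                 # a SentenceBatch
#   probs = model(batch.getBatchEncoding())    # (batch_size, seq_len)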
def train(corpus, fmt):
    print(f'starting training of {corpus} in format {fmt}...')
    data = generate_sentence_list(corpus, 'train', fmt)
    torch.manual_seed(1)

    # hyperparameters
    batch_size = 32 if fmt == 'conllu' else 4
    n_epochs = 10
    lr = 0.0001
    reg = 0.001
    dropout = 0.1  # recorded in params and in the saved file name; not applied in the model
    n_layers = 1
    n_hidden = 64
    bidirectional = True
    params = {'fm': fmt, 'nl': n_layers, 'nh': n_hidden, 'bs': batch_size, 'ne': n_epochs,
              'lr': lr, 'rg': reg, 'do': dropout, 'bi': bidirectional}

    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
    model = LSTM(batch_size, 768, n_hidden, n_layers, bidirectional=bidirectional)
    loss_fn = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)

    model.train()
    l = len(dataloader.dataset)
    for epoch in range(n_epochs):
        total_acc = 0
        total_loss = 0
        tp = 0
        fp = 0
        fn = 0
        for sentence_batch in tqdm(dataloader):
            optimizer.zero_grad()
            label_batch = sentence_batch.labels
            pred = model(sentence_batch.getBatchEncoding())
            loss = loss_fn(pred, label_batch)
            loss.backward()
            optimizer.step()
            pred_binary = (pred >= 0.5).float()
            # count true/false positives and false negatives over the whole epoch
            for i in range(label_batch.size(0)):
                for j in range(label_batch.size(1)):
                    if pred_binary[i, j] == 1.:
                        if label_batch[i, j] == 1.:
                            tp += 1
                        else:
                            fp += 1
                    elif label_batch[i, j] == 1.:
                        fn += 1
            sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
            assert sum_score <= batch_size
            total_acc += sum_score
            total_loss += loss.item()
        f1 = tp / (tp + (fp + fn) / 2)
        print(f"Epoch {epoch} Accuracy {total_acc/l} Loss {total_loss/l} F1 {f1}")
    print('done training')
    output_file = save_model(model, 'baseline', corpus, params)
    print(f'model saved at {output_file}')
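# Note: f1 = tp / (tp + (fp + fn) / 2) is algebraically 2*tp / (2*tp + fp + fn),
# i.e. the standard F1 score over the positive class, accumulated per epoch.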
def save_model(model, model_type, corpus, params):
    models_dir = 'saved_models'
    corpus_dir = os.path.join(models_dir, corpus)
    os.makedirs(corpus_dir, exist_ok=True)
    model_file = f"{model_type}_{corpus}_{params['nl']}_{params['nh']}_{params['fm']}_{params['bs']}_{params['ne']}_{params['lr']}_{params['rg']}_{params['do']}_{params['bi']}.pth"
    output_file = os.path.join(corpus_dir, model_file)
    if not os.path.isfile(output_file):
        torch.save(model, output_file)
    else:
        print("Model already exists. Do you wish to overwrite? (Y/n)")
        user = input()
        if user in ("", "Y", "y"):
            torch.save(model, output_file)
    return output_file
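# The file name encodes the hyperparameters in params
# (<type>_<corpus>_<n_layers>_<n_hidden>_<fmt>_<batch>_<epochs>_<lr>_<reg>_<dropout>_<bidir>.pth).
# torch.save(model, ...) pickles the whole nn.Module, so loading it back with
# torch.load requires this LSTM class to be importable.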
def main():
    parser = argparse.ArgumentParser(description='Train baseline model')
    parser.add_argument('--corpus', required=True, help='corpus to train on')
    parser.add_argument('--format', default='conllu', choices=['conllu', 'tok'], help='tok or conllu')
    params = parser.parse_args()
    train(params.corpus, params.format)
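# Example invocation (assuming the script is named train_model_baseline.py):
#   python train_model_baseline.py --corpus <corpus> --format conllu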
if __name__ == '__main__':
    main()