Commit b0a2c22a authored by Alice Pain

model

parent 708a2f75
"""Parse RST or SDRT corpus"""
import sys import sys
import os import os
import pandas as pd import pandas as pd
...@@ -6,58 +7,52 @@ import re ...@@ -6,58 +7,52 @@ import re
import numpy as np import numpy as np
def parse(corpus): def parse(corpus):
"""Turn corpus into a list of sentences then save it""" """Turn corpus into a list of sentences then save it"""
#print(f'parsing of {corpus} begins') print(f'parsing of {corpus} begins')
print(corpus) split = corpus.split("_")
split = corpus.split("_") corpus_dir = split[0]
corpus_dir = split[0] input_file = os.path.join("data/", corpus_dir, corpus)
input_file = os.path.join("data/", corpus_dir, corpus) output_dir = 'parsed_data'
output_dir = 'parsed_data' output_file = os.path.join(output_dir, f'parsed_{corpus}')
output_file = os.path.join(output_dir, f'parsed_{corpus}') if not os.path.isdir(output_dir):
if not os.path.isdir(output_dir): os.system(f"mkdir {output_dir}")
os.system(f"mkdir {output_dir}") if os.path.isfile(output_file + '.npz'):
if os.path.isfile(output_file): print(f'{corpus} already parsed. do you wish to overwrite? (Y/n)')
print(f'{corpus} already parsed') user = input()
sys.exit() if not (user == "" or user == "Y" or user == "y"):
#size = os.stat(corpus).st_size print('done')
#with open(corpus, "r") as f: sys.exit()
# data = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
column_names = ['tok_id','tok','1','2','3','4','5','6','7','label'] column_names = ['tok_id','tok','1','2','3','4','5','6','7','label']
df = pd.read_csv(input_file, names = column_names, skip_blank_lines=True, comment="#", sep="\t") if corpus == 'eng.rst.rstdt_train.conllu':
tok_ids = df['tok_id'].values df = pd.read_csv(input_file, names = column_names, skip_blank_lines=True, comment="#", sep="\t", engine='python', error_bad_lines=False, header=None)
toks = df['tok'].values else:
labels = df['label'].values df = pd.read_csv(input_file, names = column_names, skip_blank_lines=True, comment="#", sep="\t")
labels = np.where(labels == "BeginSeg=Yes", 1, 0) tok_ids = df['tok_id'].values
nb_segs = np.sum(labels) toks = df['tok'].values
print(nb_segs) labels = df['label'].values
if corpus.startswith('eng.rst.gum'):
np.savez_compressed(output_file, tok_ids = tok_ids, toks = toks, labels = labels) new_labels = np.zeros(labels.shape)
for i, label in enumerate(labels):
if False: if isinstance(label, str) and "BeginSeg=Yes" in label:
line = None new_labels[i] = 1
i = 0 else: new_labels[i] = 0
while line != b"\n": else:
line = data.readline() new_labels = np.where((labels == ("BeginSeg=Yes")), 1, 0) #labels == BeginSeg=Yes
if not line.startswith(b'#'): labels = new_labels
i += 1 nb_segs = np.sum(labels)
print(f"line {i}=", line)
print(i) np.savez_compressed(output_file, tok_ids = tok_ids, toks = toks, labels = labels)
m = re.search(br"#", data) #.decode('utf-8')) print(f'done parsing. data saved at {output_file}')
print(m)
i = data.find(b"#")
print(i)
#print('parsing done')
def main(): def main():
if len(sys.argv) < 2: if len(sys.argv) < 2:
print("usage: parse_corpus.py <corpus>") print("usage: parse_corpus.py <corpus>")
sys.exit() sys.exit()
corpora = sys.argv[1:] corpora = sys.argv[1:]
for corpus in corpora: for corpus in corpora:
parse(corpus) parse(corpus)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
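A quick way to sanity-check the parser output is to load the compressed arrays it writes. A minimal sketch, assuming a corpus file named eng.rst.rstdt_dev.conllu has already been parsed; the path below is illustrative, not taken from the commit:

import numpy as np

# Illustrative path: substitute a corpus you actually parsed.
data = np.load("parsed_data/parsed_eng.rst.rstdt_dev.conllu.npz", allow_pickle=True)
tok_ids, toks, labels = data["tok_ids"], data["toks"], data["labels"]
print(toks[:10])            # first ten tokens of the corpus
print(int(labels.sum()))    # number of segment-initial tokens (BeginSeg=Yes)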
"""Conllu training"""
import numpy as np import numpy as np
import sys import sys
import os import os
...@@ -12,176 +10,110 @@ from torch.utils.data import DataLoader ...@@ -12,176 +10,110 @@ from torch.utils.data import DataLoader
#from torchmetrics import F1Score #from torchmetrics import F1Score
#from torch.autograd import Variable #from torch.autograd import Variable
from tqdm import tqdm from tqdm import tqdm
from train_model import LSTM from train_model_baseline import LSTM, SentenceBatch, generate_sentence_list, toks_to_ids, make_labels, make_tok_types, make_tok_masks, collate_batch
bert = 'bert-base-multilingual-cased' bert = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(bert) tokenizer = BertTokenizer.from_pretrained(bert)
#bert_embeddings = BertModel.from_pretrained(bert)
#class LSTM(nn.Module):
#
# def __init__(self, batch_size, input_size, hidden_size, num_layers=1, bidirectional=False):
# super().__init__()
# self.batch_size = batch_size
# self.hidden_size = hidden_size
# self.bert_embeddings = BertModel.from_pretrained(bert)
# self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
# d = 2 if bidirectional else 1
# self.hiddenToLabel = nn.Linear(d * hidden_size, 1)
# self.act = nn.Sigmoid()
# #self.hidden = self.init_hidden()
#
# #def init_hidden(self):
# # #return (Variable(torch.zeros(2, self.batch_size, self.hidden_size)), Variable(torch.zeros(2, self.batch_size, self.hidden_size)))
# # return torch.zeros(self.batch_size, self.hidden_size)
#
# def forward(self, batch):
# output = self.bert_embeddings(**batch)
# output768 = output.last_hidden_state
# output64, (last_hidden_state, last_cell_state) = self.lstm(output768)
# #output64, self.hidden = self.lstm(output768, self.hidden)
# #print("output64=", output64.shape)
# #print("last_hidden_state", last_hidden_state.shape)
# #print("last_cell_state", last_cell_state.shape)
# output1 = self.hiddenToLabel(output64)
# #print("output1=", output1.shape)
# return self.act(output1[:,:,0])
#
# #lstm_out, self.hidden = self.lstm(output, self.hidden)
class SentenceBatch():
def __init__(self, tok_ids, tok_types, tok_masks, labels):
self.tok_ids = pad_sequence(tok_ids, batch_first=True)
self.tok_types = pad_sequence(tok_types, batch_first=True)
self.tok_masks = pad_sequence(tok_masks, batch_first=True)
self.labels = pad_sequence(labels, batch_first=True)
def getBatchEncoding(self):
dico = { 'input_ids': self.tok_ids, 'token_type_ids': self.tok_types, 'attention_mask': self.tok_masks }
return BatchEncoding(dico)
def generate_sentence_list(corpus, test_set):
#move that part to parse_corpus.py
parsed_data = f"parsed_data/parsed_{corpus}_{test_set}.conllu.npz"
print(parsed_data)
if not os.path.isfile(parsed_data):
print("you must parse the corpus before testing it. please run parse_corpus.py")
sys.exit()
data = np.load(parsed_data, allow_pickle = True)
tok_ids, toks, labels = [data[f] for f in data.files]
#print(tok_ids.shape, toks.shape, labels.shape)
sentences = [] def test(corpus, model_path, test_set, fmt):
previous_i = 0 model = torch.load(model_path)
for index, tok_id in enumerate(tok_ids): print(f'Model:\t{model_path}\nEval:\t{corpus}_{test_set}\nFormat:\t{fmt}')
if tok_id == 1 and index > 0: data = generate_sentence_list(corpus, test_set, fmt)
if index - previous_i <= 510: batch_size = 32
sentences.append((toks[previous_i:index], labels[previous_i:index])) dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
previous_i = index
else: model.eval()
sep = previous_i + 510 loss_fn = nn.BCELoss()
sentences.append((toks[previous_i:sep], labels[previous_i:sep]))
if sep - previous_i > 510: l = len(dataloader.dataset)
print("still too long sentence...") errors = []
sys.exit()
sentences.append((toks[sep:index], labels[sep:index])) with torch.no_grad():
return sentences total_acc = 0
total_loss = 0
def toks_to_ids(sentence): tp = 0
#print("sentence=", sentence) fp = 0
tokens = ['[CLS]'] + list(sentence) + ['[SEP]'] fn = 0
#print("tokens=", tokens)
return torch.tensor(tokenizer.convert_tokens_to_ids(tokens)) #len(tokens) for sentence_batch in tqdm(dataloader):
#print("token_ids=", token_ids) label_batch = sentence_batch.labels
#print("label_batch", label_batch.shape)
pred = model(sentence_batch.getBatchEncoding())
#print("pred", pred.shape)
loss = loss_fn(pred, label_batch)
pred_binary = (pred >= 0.5).float()
for i in range(label_batch.size(0)):
for j in range(label_batch.size(1)):
if pred_binary[i,j] == 1.:
if label_batch[i,j] == 1.:
tp += 1
else: fp += 1
elif label_batch[i,j] == 1.: fn += 1
sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
assert (sum_score <= batch_size)
for i, sentence_id in enumerate(sentence_batch.sentence_ids):
if (pred_binary[i] != label_batch[i]).sum() > 0: #if there's at least one error
errors.append((sentence_id, pred_binary[i]))
total_acc += sum_score
total_loss += loss.item() #*label_batch.size(0)
def make_labels(sentence): precision = tp / (tp + fp)
zero = np.array([0]) recall = tp / (tp + fn)
add_two = np.concatenate((np.concatenate((zero, sentence)), zero)) #add label 0 for [CLS] and [SEP] f1 = 2 * (precision * recall) / (precision + recall)
return torch.from_numpy(add_two).float()
def make_tok_types(l): #print_errors(errors, data)
return torch.zeros(l, dtype=torch.int32)
def make_tok_masks(l): print(f"Acc\t{total_acc/l}\nLoss\t{total_loss/l}\nP\t{precision}\nR\t{recall}\nF1\t{f1}\n\n")
return torch.ones(l, dtype=torch.int32)
def collate_batch(batch): def print_segmentation(toks, labels):
token_batch, label_batch = [i for i, _ in batch], [j for _, j in batch] s = ""
#mappings = [make_mapping(sentence) for sentence in token_batch] for i, tok in enumerate(toks):
labels = [make_labels(sentence) for sentence in label_batch] if i+1 < len(labels) and labels[i+1] == 1:
tok_ids = [toks_to_ids(sentence) for sentence in token_batch] s += "| "
lengths = [len(toks) for toks in tok_ids] s += str(tok) + " "
tok_types = [make_tok_types(l) for l in lengths] print(s)
tok_masks = [make_tok_masks(l) for l in lengths]
return SentenceBatch(tok_ids, tok_types, tok_masks, labels)
def test(corpus, model, test_set):
print(f'starting testing on {test_set} set...')
data = generate_sentence_list(corpus, test_set)
batch_size = 32
dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
model.eval()
loss_fn = nn.BCELoss()
l = len(dataloader.dataset)
with torch.no_grad():
total_acc = 0
total_loss = 0
tp = 0
fp = 0
fn = 0
for sentence_batch in tqdm(dataloader):
label_batch = sentence_batch.labels
#print("label_batch", label_batch.shape)
pred = model(sentence_batch.getBatchEncoding())
#print("pred", pred.shape)
loss = loss_fn(pred, label_batch)
pred_binary = (pred >= 0.5).float()
for i in range(label_batch.size(0)):
for j in range(label_batch.size(1)):
if pred_binary[i,j] == 1.:
if label_batch[i,j] == 1.:
tp += 1
else: fp += 1
elif label_batch[i,j] == 1.: fn += 1
#print("tp,fp,fn=",tp/label_batch.size(1),fp/label_batch.size(1),fn/label_batch.size(1))
#nb_1 = pred_binary.sum()
#print("nb predicted 1=", nb_1)
sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
assert (sum_score <= batch_size)
total_acc += sum_score
total_loss += loss.item() #*label_batch.size(0)
f1 = tp / (tp + (fp + fn) / 2) def print_errors(errors, data, max_print=25):
print(f"Results on {test_set}: \nAccuracy {total_acc/l} \nLoss {total_loss/l} \nF1 {f1}")
print(f'Reporting {max_print} errors')
print('done testing') max_print = min(max_print, len(errors))
for sentence_id, pred in errors[:max_print]:
_, (toks, labels) = data[sentence_id]
print(f'Sentence {sentence_id}')
print('Predicted:')
print_segmentation(toks, pred)
print('True:')
labels = [0] + list(labels) + [0]
print_segmentation(toks, labels)
def main(): def main():
if len(sys.argv) < 2 or len(sys.argv) > 3: if len(sys.argv) < 2 or len(sys.argv) > 4:
print("usage: test_model.py <model> [<test/dev>]") print("usage: test_model.py <model> [<test_corpus>] [<test/dev/train>]")
sys.exit() sys.exit()
model_path = sys.argv[1] model_path = sys.argv[1]
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
print("model not found. please train the model first. please provide a relative path from discut dir.") print("model not found. please train the model first. please provide a relative path from discut dir.")
sys.exit() sys.exit()
test_set = 'test' test_set = 'test'
if len(sys.argv) == 3: if len(sys.argv) == 4:
if sys.argv[2] == 'dev': if sys.argv[3] == 'dev':
test_set = 'dev' test_set = 'dev'
elif sys.argv[2] != 'test': elif sys.arvg[3] == 'train':
print("usage: test_model.py <model> [<test/dev>]") test_set = 'train'
sys.exit() elif sys.argv[3] != 'test':
model = torch.load(model_path) print("usage: test_model.py <model> [<test_corpus>] [<test/dev/train>]")
model_split = model_path.split("/") sys.exit()
corpus = model_split[1] model_split = model_path.split("/")
test(corpus, model, test_set) corpus = model_split[1]
if len(sys.argv) >= 3:
corpus = sys.argv[2]
fmt = 'conllu'
params = model_split[2]
params_split = params.split("_")
if params_split[1] == 'tok':
fmt = 'tok'
test(corpus, model_path, test_set, fmt)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
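To make the error report above concrete, here is a toy, self-contained illustration of print_segmentation; the tokens and labels are invented for the example, and the label vector carries the extra leading [CLS] slot that print_errors passes along:

# Copied from print_segmentation above, with toy inputs:
def print_segmentation(toks, labels):
    s = ""
    for i, tok in enumerate(toks):
        if i + 1 < len(labels) and labels[i + 1] == 1:
            s += "| "
        s += str(tok) + " "
    print(s)

toks = ["The", "report", "was", "late", "because", "the", "data", "changed"]
labels = [0, 1, 0, 0, 0, 1, 0, 0, 0, 0]   # position 0 is the [CLS] slot, so labels[i+1] belongs to toks[i]
print_segmentation(toks, labels)          # prints: | The report was late | because the data changed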
"""Conllu training"""
import numpy as np
import sys
import os
from transformers import BertTokenizer, BertModel
from transformers.tokenization_utils_base import BatchEncoding
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
#from torchmetrics import F1Score
#from torch.autograd import Variable
from tqdm import tqdm
bert = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(bert)
#bert_embeddings = BertModel.from_pretrained(bert)
class LSTM(nn.Module):
def __init__(self, batch_size, input_size, hidden_size, num_layers=1, bidirectional=False):
super().__init__()
self.batch_size = batch_size
self.hidden_size = hidden_size
self.bert_embeddings = BertModel.from_pretrained(bert)
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
d = 2 if bidirectional else 1
self.hiddenToLabel = nn.Linear(d * hidden_size, 1)
self.act = nn.Sigmoid()
#self.hidden = self.init_hidden()
#def init_hidden(self):
# #return (Variable(torch.zeros(2, self.batch_size, self.hidden_size)), Variable(torch.zeros(2, self.batch_size, self.hidden_size)))
# return torch.zeros(self.batch_size, self.hidden_size)
def forward(self, batch):
output = self.bert_embeddings(**batch)
output768 = output.last_hidden_state
output64, (last_hidden_state, last_cell_state) = self.lstm(output768)
#output64, self.hidden = self.lstm(output768, self.hidden)
#print("output64=", output64.shape)
#print("last_hidden_state", last_hidden_state.shape)
#print("last_cell_state", last_cell_state.shape)
output1 = self.hiddenToLabel(output64)
#print("output1=", output1.shape)
return self.act(output1[:,:,0])
#lstm_out, self.hidden = self.lstm(output, self.hidden)
class SentenceBatch():
def __init__(self, tok_ids, tok_types, tok_masks, labels):
self.tok_ids = pad_sequence(tok_ids, batch_first=True)
self.tok_types = pad_sequence(tok_types, batch_first=True)
self.tok_masks = pad_sequence(tok_masks, batch_first=True)
self.labels = pad_sequence(labels, batch_first=True)
def getBatchEncoding(self):
dico = { 'input_ids': self.tok_ids, 'token_type_ids': self.tok_types, 'attention_mask': self.tok_masks }
return BatchEncoding(dico)
def generate_sentence_list(corpus):
#move that part to parse_corpus.py
parsed_data = f"parsed_data/parsed_{corpus}_train.conllu.npz"
if not os.path.isfile(parsed_data):
print("you must parse the corpus before training it")
sys.exit()
data = np.load(parsed_data, allow_pickle = True)
tok_ids, toks, labels = [data[f] for f in data.files]
#print(tok_ids.shape, toks.shape, labels.shape)
sentences = []
previous_i = 0
for index, tok_id in enumerate(tok_ids):
if tok_id == 1 and index > 0:
if index - previous_i <= 510:
sentences.append((toks[previous_i:index], labels[previous_i:index]))
previous_i = index
else:
sep = previous_i + 510
sentences.append((toks[previous_i:sep], labels[previous_i:sep]))
if sep - previous_i > 510:
print("still too long sentence...")
sys.exit()
sentences.append((toks[sep:index], labels[sep:index]))
return sentences
def toks_to_ids(sentence):
#print("sentence=", sentence)
tokens = ['[CLS]'] + list(sentence) + ['[SEP]']
#print("tokens=", tokens)
return torch.tensor(tokenizer.convert_tokens_to_ids(tokens)) #len(tokens)
#print("token_ids=", token_ids)
def make_labels(sentence):
zero = np.array([0])
add_two = np.concatenate((np.concatenate((zero, sentence)), zero)) #add label 0 for [CLS] and [SEP]
return torch.from_numpy(add_two).float()
def make_tok_types(l):
return torch.zeros(l, dtype=torch.int32)
def make_tok_masks(l):
return torch.ones(l, dtype=torch.int32)
def collate_batch(batch):
token_batch, label_batch = [i for i, _ in batch], [j for _, j in batch]
#mappings = [make_mapping(sentence) for sentence in token_batch]
labels = [make_labels(sentence) for sentence in label_batch]
tok_ids = [toks_to_ids(sentence) for sentence in token_batch]
lengths = [len(toks) for toks in tok_ids]
tok_types = [make_tok_types(l) for l in lengths]
tok_masks = [make_tok_masks(l) for l in lengths]
return SentenceBatch(tok_ids, tok_types, tok_masks, labels)
def train(corpus):
print('starting training...')
data = generate_sentence_list(corpus)
batch_size = 32
dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
torch.manual_seed(1)
num_epochs = 10
lr = 0.0001
bidirectional = True
params = { 'bs': batch_size, 'ne': num_epochs, 'lr': lr, 'bi': bidirectional }
model = LSTM(batch_size, 768, 64, num_layers=1, bidirectional=bidirectional) #64
loss_fn = nn.BCELoss() #BCELoss
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model.train()
l = len(dataloader.dataset)
for epoch in range(num_epochs):
total_acc = 0
total_loss = 0
tp = 0
fp = 0
fn = 0
for sentence_batch in tqdm(dataloader):
optimizer.zero_grad()
label_batch = sentence_batch.labels
#print("label_batch", label_batch.shape)
pred = model(sentence_batch.getBatchEncoding())
#print("pred", pred.shape)
loss = loss_fn(pred, label_batch)
loss.backward()
optimizer.step()
pred_binary = (pred >= 0.5).float()
for i in range(label_batch.size(0)):
for j in range(label_batch.size(1)):
if pred_binary[i,j] == 1.:
if label_batch[i,j] == 1.:
tp += 1
else: fp += 1
elif label_batch[i,j] == 1.: fn += 1
#print("tp,fp,fn=",tp/label_batch.size(1),fp/label_batch.size(1),fn/label_batch.size(1))
#nb_1 = pred_binary.sum()
#print("nb predicted 1=", nb_1)
sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
assert (sum_score <= batch_size)
total_acc += sum_score
total_loss += loss.item() #*label_batch.size(0)
f1 = tp / (tp + (fp + fn) / 2)
print(f"Epoch {epoch} Accuracy {total_acc/l} Loss {total_loss/l} F1 {f1}")
print('done training')
output_file = save_model(model, corpus, params)
print('model saved at {output_file}')
def save_model(model, corpus, params):
models_dir = 'saved_models'
if not os.path.isdir(models_dir):
os.system(f'mkdir {models_dir}')
corpus_dir = os.path.join(models_dir, corpus)
if not os.path.isdir(corpus_dir):
os.system(f'mkdir {corpus_dir}')
model_file = f"{corpus}_{params['bs']}_{params['ne']}_{params['lr']}_{params['bi']}.pth"
output_file = os.path.join(corpus_dir, model_file)
if not os.path.isfile(output_file):
torch.save(model, output_file)
else:
print("Model already registers. Do you wish to overwrite? (Y/n)")
user = input()
if user == "" or user == "Y" or user == "y":
torch.save(model, output_file)
return output_file
def main():
if len(sys.argv) < 2:
print("usage: train_model.py <corpus>")
sys.exit()
corpora = sys.argv[1:]
for corpus in corpora:
train(corpus)
if __name__ == '__main__':
main()
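The per-epoch F1 printed here, tp / (tp + (fp + fn) / 2), is algebraically the same quantity as the 2 * precision * recall / (precision + recall) that test_model.py reports. A quick check with arbitrary counts:

tp, fp, fn = 120, 30, 50                                # arbitrary counts for the check
precision = tp / (tp + fp)
recall = tp / (tp + fn)
print(2 * precision * recall / (precision + recall))    # 0.75 (up to floating-point rounding)
print(tp / (tp + (fp + fn) / 2))                        # 0.75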
train_model_baseline.py

import numpy as np
import sys
import os
from transformers import BertTokenizer, BertModel
from transformers.tokenization_utils_base import BatchEncoding
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
#from torchmetrics import F1Score
#from torch.autograd import Variable
from tqdm import tqdm

bert = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(bert)
#bert_embeddings = BertModel.from_pretrained(bert)

class LSTM(nn.Module):

    def __init__(self, batch_size, input_size, hidden_size, num_layers=1, bidirectional=False):
        super().__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.bert_embeddings = BertModel.from_pretrained(bert)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
        d = 2 if bidirectional else 1
        self.hiddenToLabel = nn.Linear(d * hidden_size, 1)
        self.act = nn.Sigmoid()

    def forward(self, batch):
        output = self.bert_embeddings(**batch)
        output768 = output.last_hidden_state
        output64, (last_hidden_state, last_cell_state) = self.lstm(output768)
        #output64, self.hidden = self.lstm(output768, self.hidden)
        #print("output64=", output64.shape)
        #print("last_hidden_state", last_hidden_state.shape)
        #print("last_cell_state", last_cell_state.shape)
        output1 = self.hiddenToLabel(output64)
        #print("output1=", output1.shape)
        return self.act(output1[:, :, 0])
        #lstm_out, self.hidden = self.lstm(output, self.hidden)

class SentenceBatch():

    def __init__(self, sentence_ids, tok_ids, tok_types, tok_masks, labels):
        self.sentence_ids = sentence_ids
        self.tok_ids = pad_sequence(tok_ids, batch_first=True)
        self.tok_types = pad_sequence(tok_types, batch_first=True)
        self.tok_masks = pad_sequence(tok_masks, batch_first=True)
        self.labels = pad_sequence(labels, batch_first=True)

    def getBatchEncoding(self):
        dico = { 'input_ids': self.tok_ids, 'token_type_ids': self.tok_types, 'attention_mask': self.tok_masks }
        return BatchEncoding(dico)

def generate_sentence_list(corpus, sset, fmt):
    #move that part to parse_corpus.py
    parsed_data = os.path.join("parsed_data", f"parsed_{corpus}_{sset}.{fmt}.npz")
    if not os.path.isfile(parsed_data):
        print("you must parse the corpus before training it")
        sys.exit()
    data = np.load(parsed_data, allow_pickle=True)
    tok_ids, toks, labels = [data[f] for f in data.files]
    sentences = []
    if fmt == 'conllu':
        previous_i = 0
        for index, tok_id in enumerate(tok_ids):
            if index > 0 and (tok_id == 1 or tok_id == '1' or (isinstance(tok_id, str) and tok_id.startswith('1-'))):
                if index - previous_i <= 510:
                    sentences.append((toks[previous_i:index], labels[previous_i:index]))
                    previous_i = index
                else:
                    sep = previous_i + 510
                    sentences.append((toks[previous_i:sep], labels[previous_i:sep]))
                    if sep - previous_i > 510:
                        print("still too long sentence...")
                        sys.exit()
                    sentences.append((toks[sep:index], labels[sep:index]))
    else:
        max_length = 510
        nb_toks = len(tok_ids)
        for i in range(0, nb_toks - max_length, max_length):
            #add slices of 510 tokens
            sentences.append((toks[i:(i + max_length)], labels[i:(i + max_length)]))
        sentences.append((toks[-(nb_toks % max_length):], labels[-(nb_toks % max_length):]))
    indexed_sentences = [0] * len(sentences)
    for i, sentence in enumerate(sentences):
        indexed_sentences[i] = (i, sentence)
    return indexed_sentences

def toks_to_ids(sentence):
    #print("sentence=", sentence)
    tokens = ['[CLS]'] + list(sentence) + ['[SEP]']
    #print("tokens=", tokens)
    return torch.tensor(tokenizer.convert_tokens_to_ids(tokens))  #len(tokens)
    #print("token_ids=", token_ids)

def make_labels(sentence):
    zero = np.array([0])
    add_two = np.concatenate((np.concatenate((zero, sentence)), zero))  #add label 0 for [CLS] and [SEP]
    return torch.from_numpy(add_two).float()

def make_tok_types(l):
    return torch.zeros(l, dtype=torch.int32)

def make_tok_masks(l):
    return torch.ones(l, dtype=torch.int32)

def collate_batch(batch):
    sentence_ids, token_batch, label_batch = [i for i, (_, _) in batch], [j for _, (j, _) in batch], [k for _, (_, k) in batch]
    #mappings = [make_mapping(sentence) for sentence in token_batch]
    labels = [make_labels(sentence) for sentence in label_batch]
    tok_ids = [toks_to_ids(sentence) for sentence in token_batch]
    lengths = [len(toks) for toks in tok_ids]
    tok_types = [make_tok_types(l) for l in lengths]
    tok_masks = [make_tok_masks(l) for l in lengths]
    return SentenceBatch(sentence_ids, tok_ids, tok_types, tok_masks, labels)

def train(corpus, fmt):
    print(f'starting training of {corpus} in format {fmt}...')
    data = generate_sentence_list(corpus, 'train', fmt)
    torch.manual_seed(1)
    #PARAMETERS
    batch_size = 32 if fmt == 'conllu' else 4
    num_epochs = 10
    lr = 0.0001
    reg = 0.0005
    dropout = 0.01
    bidirectional = True
    params = { 'fm': fmt, 'bs': batch_size, 'ne': num_epochs, 'lr': lr, 'rg': reg, 'do': dropout, 'bi': bidirectional }
    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
    model = LSTM(batch_size, 768, 64, num_layers=1, bidirectional=bidirectional)
    loss_fn = nn.BCELoss()  #BCELoss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)
    model.train()
    l = len(dataloader.dataset)
    for epoch in range(num_epochs):
        total_acc = 0
        total_loss = 0
        tp = 0
        fp = 0
        fn = 0
        for sentence_batch in tqdm(dataloader):
            optimizer.zero_grad()
            label_batch = sentence_batch.labels
            #print("label_batch", label_batch.shape)
            #print("tok_ids", sentence_batch.tok_ids.shape)
            pred = model(sentence_batch.getBatchEncoding())
            #print("pred", pred.shape)
            loss = loss_fn(pred, label_batch)
            loss.backward()
            optimizer.step()
            pred_binary = (pred >= 0.5).float()
            for i in range(label_batch.size(0)):
                for j in range(label_batch.size(1)):
                    if pred_binary[i, j] == 1.:
                        if label_batch[i, j] == 1.:
                            tp += 1
                        else:
                            fp += 1
                    elif label_batch[i, j] == 1.:
                        fn += 1
            #print("tp,fp,fn=", tp/label_batch.size(1), fp/label_batch.size(1), fn/label_batch.size(1))
            #nb_1 = pred_binary.sum()
            #print("nb predicted 1=", nb_1)
            sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
            assert (sum_score <= batch_size)
            total_acc += sum_score
            total_loss += loss.item()  #*label_batch.size(0)
        f1 = tp / (tp + (fp + fn) / 2)
        print(f"Epoch {epoch} Accuracy {total_acc/l} Loss {total_loss/l} F1 {f1}")
    print('done training')
    output_file = save_model(model, corpus, params)
    print(f'model saved at {output_file}')

def save_model(model, corpus, params):
    models_dir = 'saved_models'
    if not os.path.isdir(models_dir):
        os.system(f'mkdir {models_dir}')
    corpus_dir = os.path.join(models_dir, corpus)
    if not os.path.isdir(corpus_dir):
        os.system(f'mkdir {corpus_dir}')
    model_file = f"{corpus}_{params['fm']}_{params['bs']}_{params['ne']}_{params['lr']}_{params['rg']}_{params['do']}_{params['bi']}.pth"
    output_file = os.path.join(corpus_dir, model_file)
    if not os.path.isfile(output_file):
        torch.save(model, output_file)
    else:
        print("Model already exists. Do you wish to overwrite? (Y/n)")
        user = input()
        if user == "" or user == "Y" or user == "y":
            torch.save(model, output_file)
    return output_file

def main():
    if len(sys.argv) < 2 or len(sys.argv) > 3:
        print("usage: train_model_baseline.py <corpus> [<conllu/tok>]")
        sys.exit()
    corpus = sys.argv[1]
    fmt = 'conllu'
    if len(sys.argv) == 3:
        fmt = sys.argv[2]
        if fmt != 'conllu' and fmt != 'tok':
            print("usage: train_model_baseline.py <corpus> [<conllu/tok>]")
            sys.exit()
    train(corpus, fmt)

if __name__ == '__main__':
    main()
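As a minimal sketch of what make_labels and the pad_sequence calls in SentenceBatch do to a batch of label vectors (toy label arrays only, so no BERT download is needed):

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

def make_labels(sentence):
    # same as make_labels above: label 0 added for the [CLS] and [SEP] positions
    zero = np.array([0])
    add_two = np.concatenate((np.concatenate((zero, sentence)), zero))
    return torch.from_numpy(add_two).float()

label_vectors = [make_labels(np.array(l)) for l in ([1, 0, 0, 1, 0], [1, 0])]
batch = pad_sequence(label_vectors, batch_first=True)
print(batch.shape)  # torch.Size([2, 7]): 5 tokens plus [CLS]/[SEP] slots; the shorter row is zero-padded
print(batch[1])     # tensor([0., 1., 0., 0., 0., 0., 0.])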