diff --git a/trytorch/check_parse.py b/trytorch/check_parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecb1fb3d2cb58657026c4655e38a23ed36278632
--- /dev/null
+++ b/trytorch/check_parse.py
@@ -0,0 +1,16 @@
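+#quick sanity check of a parsed .npz archive: prints string tokens longer than 35 characters,
+#and the token id and type of any entry that is not a string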
+import sys
+import numpy as np
+
+path = sys.argv[1]
+data = np.load(path, allow_pickle=True)
+tok_ids, toks, labels = [data[f] for f in data.files]
+
+for i, tok in enumerate(toks):
+    if isinstance(tok, str):
+        if len(tok) > 35:
+            print(tok)
+    else:
+        print(tok_ids[i])
+        print(type(tok))
+
+
diff --git a/trytorch/check_rich.py b/trytorch/check_rich.py
new file mode 100644
index 0000000000000000000000000000000000000000..6233fe54f6418a7afc24cfa88c9409a32c576de3
--- /dev/null
+++ b/trytorch/check_rich.py
@@ -0,0 +1,34 @@
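+#quick sanity check of a rich parsed .npz archive: prints the inventories of upos / head / deprel values,
+#then string tokens longer than 35 characters and the token id and type of any entry that is not a string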
+import sys
+import numpy as np
+
+path = sys.argv[1]
+data = np.load(path, allow_pickle=True)
+tok_ids, toks, labels, upos, idheads, deprels = [data[f] for f in data.files]
+
+print(len(toks))
+print(len(upos))
+
+labs_upos = np.unique(upos)
+labs_idheads = np.unique(idheads)
+labs_deprel = np.unique(deprels)
+print(labs_upos)
+print(labs_idheads)
+print(labs_deprel)
+
+for i, tok in enumerate(toks):
+    if isinstance(tok, str):
+        if len(tok) > 35:
+            print(tok)
+    else:
+        print(tok_ids[i])
+        print(type(tok))
+
+for i, tok in enumerate(upos):
+    if isinstance(tok, str):
+        if len(tok) > 35:
+            print(tok)
+    else:
+        print(tok_ids[i])
+        print(type(tok))
+
+
diff --git a/trytorch/parse_corpus.py b/trytorch/parse_corpus.py
index 59f8e91168d4230e81628c8720a08094d416387c..f24fc233d161fda869c94bbdec93f1734f8924df 100644
--- a/trytorch/parse_corpus.py
+++ b/trytorch/parse_corpus.py
@@ -6,16 +6,22 @@ import mmap
 import re
 import numpy as np
 
-def parse(corpus):
+def parse(corpus, rich):
     """Turn corpus into a list of sentences then save it"""
     print(f'parsing of {corpus} begins')
     split = corpus.split("_")
     corpus_dir = split[0]
     input_file = os.path.join("data/", corpus_dir, corpus)
     output_dir = 'parsed_data'
-    output_file = os.path.join(output_dir, f'parsed_{corpus}')
+    if rich:
+        if not corpus.endswith('conllu'):
+            print('rich parsing is only possible with a conllu file')
+            sys.exit()
+        output_file = os.path.join(output_dir, f'parsedrich_{corpus}')
+    else:
+        output_file = os.path.join(output_dir, f'parsed_{corpus}')
     if not os.path.isdir(output_dir):
-        os.system(f"mkdir {output_dir}")
+        os.mkdir(output_dir)
     if os.path.isfile(output_file + '.npz'):
         print(f'{corpus} already parsed. do you wish to overwrite? (Y/n)')
         user = input()
@@ -23,11 +29,11 @@ def parse(corpus):
         print('done')
         sys.exit()
 
-    column_names = ['tok_id','tok','1','2','3','4','5','6','7','label']
-    if corpus == 'eng.rst.rstdt_train.conllu':
-        df = pd.read_csv(input_file, names = column_names, skip_blank_lines=True, comment="#", sep="\t", engine='python', error_bad_lines=False, header=None)
-    else:
-        df = pd.read_csv(input_file, names = column_names, skip_blank_lines=True, comment="#", sep="\t")
+    column_names = ['tok_id','tok','lemma','upos','xpos','gram','idhead','deprel','type','label']
+    #if corpus.startswith('eng.rst.rstdt_train') or corpus.startswith('por.rst.cstn'):
+    #    df = pd.read_csv(input_file, names = column_names, skip_blank_lines=True, sep="\t", quoting=3, comment='#')
+    #else:
+    df = pd.read_csv(input_file, names = column_names, skip_blank_lines=True, comment="#", sep="\t", quoting=3)
     tok_ids = df['tok_id'].values
     toks = df['tok'].values
     labels = df['label'].values
@@ -38,21 +44,35 @@ def parse(corpus):
                 new_labels[i] = 1
             else:
                 new_labels[i] = 0
     else:
-        new_labels = np.where((labels == ("BeginSeg=Yes")), 1, 0) #labels == BeginSeg=Yes
+        new_labels = np.where((labels == "BeginSeg=Yes"), 1, 0)
     labels = new_labels
     nb_segs = np.sum(labels)
-
-    np.savez_compressed(output_file, tok_ids = tok_ids, toks = toks, labels = labels)
+
+    if rich:
+        upos = df['upos'].values
+        idheads = df['idhead'].values
+        deprels = df['deprel'].values
+        np.savez_compressed(output_file, tok_ids = tok_ids, toks = toks, labels = labels, upos = upos, idheads = idheads, deprels = deprels)
+
+    else:
+        np.savez_compressed(output_file, tok_ids = tok_ids, toks = toks, labels = labels)
     print(f'done parsing. data saved at {output_file}')
 
 def main():
-    if len(sys.argv) < 2:
-        print("usage: parse_corpus.py <corpus>")
+    if len(sys.argv) < 2 or len(sys.argv) > 3:
+        print("usage: parse_corpus.py <corpus> [<rich>]")
+        sys.exit()
+    corpus = sys.argv[1]
+    rich = False
+    #rich = True means we collect not only tokens and labels but also dependency information (conllu files only)
+    if len(sys.argv) > 2:
+        if sys.argv[2] == 'rich':
+            rich = True
+        else:
+            print("usage: parse_corpus.py <corpus> [<rich>]")
-        sys.exit()
+            sys.exit()
-    corpora = sys.argv[1:]
-    for corpus in corpora:
-        parse(corpus)
+    parse(corpus, rich)
 
 if __name__ == '__main__':
     main()
diff --git a/trytorch/parsed_data/parsed_fra.sdrt.annodis_train-mini.conllu.npz b/trytorch/parsed_data/parsed_fra.sdrt.annodis_train-mini.conllu.npz
deleted file mode 100644
index 6821ad66bd027b64da05151676c2a8d414e9875f..0000000000000000000000000000000000000000
Binary files a/trytorch/parsed_data/parsed_fra.sdrt.annodis_train-mini.conllu.npz and /dev/null differ
diff --git a/trytorch/test_model.py b/trytorch/test_model.py
index 26bc6ff327183bf136dc864533a4959326157944..df60cfcbe008aa6ca28bf679a80a44e1a211b73c 100644
--- a/trytorch/test_model.py
+++ b/trytorch/test_model.py
@@ -1,27 +1,28 @@
 import numpy as np
 import sys
 import os
-from transformers import BertTokenizer, BertModel
-from transformers.tokenization_utils_base import BatchEncoding
+import argparse
 import torch
 from torch import nn
-from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
-#from torchmetrics import F1Score
-#from torch.autograd import Variable
 from tqdm import tqdm
-from train_model_baseline import LSTM, SentenceBatch, generate_sentence_list, toks_to_ids, make_labels, make_tok_types, make_tok_masks, collate_batch
+from train_model_baseline import LSTM, SentenceBatch, generate_sentence_list, collate_batch
+from train_model_rich import RichLSTM, generate_rich_sentence_list, collate_rich_batch
 
 bert = 'bert-base-multilingual-cased'
-tokenizer = BertTokenizer.from_pretrained(bert)
 
-def test(corpus, model_path, test_set, fmt):
+def test(model_path, model_type, corpus, test_set, fmt, show_errors):
     model = torch.load(model_path)
-    print(f'Model:\t{model_path}\nEval:\t{corpus}_{test_set}\nFormat:\t{fmt}')
-    data = generate_sentence_list(corpus, test_set, fmt)
+    print(f'Model:\t{model_path}\nType:\t{model_type}\nEval:\t{corpus}_{test_set}\nFormat:\t{fmt}')
+
     batch_size = 32
-    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
-
+    if model_type == 'baseline':
+        data = generate_sentence_list(corpus, test_set, fmt)
+        dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
+    else:
+        data, _, _ = generate_rich_sentence_list(corpus, test_set, fmt)
+        dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_rich_batch)
+
     model.eval()
 
     loss_fn = nn.BCELoss()
@@ -29,27 +30,25 @@ def test(corpus, model_path, test_set, fmt):
     errors = []
 
     with torch.no_grad():
-        total_acc = 0
-        total_loss = 0
-        tp = 0
-        fp = 0
-        fn = 0
+        total_acc, total_loss = 0, 0
+        tp, fp, fn = 0, 0, 0
         for sentence_batch in tqdm(dataloader):
            label_batch = sentence_batch.labels
-            #print("label_batch", label_batch.shape)
-            pred = model(sentence_batch.getBatchEncoding())
-            #print("pred", pred.shape)
+            if model_type == 'baseline':
+                pred = model(sentence_batch.getBatchEncoding())
+            else:
+                pred = model(sentence_batch.getBatchEncoding(), sentence_batch.uposes, sentence_batch.deprels, sentence_batch.dheads)
             loss = loss_fn(pred, label_batch)
             pred_binary = (pred >= 0.5).float()
 
             for i in range(label_batch.size(0)):
                 for j in range(label_batch.size(1)):
-                    if pred_binary[i,j] == 1.:
-                        if label_batch[i,j] == 1.:
-                            tp += 1
-                        else: fp += 1
-                    elif label_batch[i,j] == 1.: fn += 1
+                    if pred_binary[i,j] == 1.:
+                        if label_batch[i,j] == 1.:
+                            tp += 1
+                        else: fp += 1
+                    elif label_batch[i,j] == 1.: fn += 1
 
             sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
             assert (sum_score <= batch_size)
             for i, sentence_id in enumerate(sentence_batch.sentence_ids):
@@ -61,8 +60,9 @@
     precision = tp / (tp + fp)
     recall = tp / (tp + fn)
     f1 = 2 * (precision * recall) / (precision + recall)
-
-    #print_errors(errors, data)
+
+    if show_errors > 0:
+        print_errors(errors, data, max_print=show_errors)
 
     print(f"Acc\t{total_acc/l}\nLoss\t{total_loss/l}\nP\t{precision}\nR\t{recall}\nF1\t{f1}\n\n")
 
@@ -72,10 +72,9 @@ def print_segmentation(toks, labels):
         if i+1 < len(labels) and labels[i+1] == 1:
             s += "| "
         s += str(tok) + " "
-    print(s)
+    print(s + '\n')
 
 def print_errors(errors, data, max_print=25):
-    print(f'Reporting {max_print} errors')
     max_print = min(max_print, len(errors))
 
     for sentence_id, pred in errors[:max_print]:
@@ -88,32 +87,23 @@ def print_errors(errors, data, max_print=25):
         print_segmentation(toks, labels)
 
 def main():
-    if len(sys.argv) < 2 or len(sys.argv) > 4:
-        print("usage: test_model.py <model> [<test_corpus>] [<test/dev/train>]")
-        sys.exit()
-    model_path = sys.argv[1]
-    if not os.path.isfile(model_path):
-        print("model not found. please train the model first. please provide a relative path from discut dir.")
-        sys.exit()
-    test_set = 'test'
-    if len(sys.argv) == 4:
-        if sys.argv[3] == 'dev':
-            test_set = 'dev'
-        elif sys.arvg[3] == 'train':
-            test_set = 'train'
-        elif sys.argv[3] != 'test':
-            print("usage: test_model.py <model> [<test_corpus>] [<test/dev/train>]")
-            sys.exit()
-    model_split = model_path.split("/")
-    corpus = model_split[1]
-    if len(sys.argv) >= 3:
-        corpus = sys.argv[2]
-    fmt = 'conllu'
-    params = model_split[2]
-    params_split = params.split("_")
-    if params_split[1] == 'tok':
-        fmt = 'tok'
-    test(corpus, model_path, test_set, fmt)
+    parser = argparse.ArgumentParser(description='Test a model on a dataset')
+    parser.add_argument('--model', help='Path to .pth saved model')
+    parser.add_argument('--format', default='conllu', help='tok or conllu')
+    parser.add_argument('--type', default='baseline', help="baseline or rich model")
+    parser.add_argument('--corpus', help='corpus to test on')
+    parser.add_argument('--set', default='test', help='portion of the corpus to test on')
+    parser.add_argument('--errors', default=0, type=int, help='number of prediction errors to display on standard output')
+
+    params = parser.parse_args()
+
+    if not os.path.isfile(params.model):
+        print("model not found. please train the model first. please provide a relative path from discut dir.")
+        sys.exit()
+    if params.type == 'rich' and params.format == 'tok':
+        print('a rich model requires a .conllu file')
+        sys.exit()
+    test(params.model, params.type, params.corpus, params.set, params.format, params.errors)
 
 if __name__ == '__main__':
     main()
diff --git a/trytorch/train_model_baseline.py b/trytorch/train_model_baseline.py
index 4acbd84cf3809e3c764d78118548bfeb392c53dd..b6038b51e56698c134c87b1aa2a6fddaccad3066 100644
--- a/trytorch/train_model_baseline.py
+++ b/trytorch/train_model_baseline.py
@@ -12,7 +12,6 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 bert = 'bert-base-multilingual-cased'
-tokenizer = BertTokenizer.from_pretrained(bert)
 #bert_embeddings = BertModel.from_pretrained(bert)
 
 class LSTM(nn.Module):
@@ -42,14 +41,16 @@ class LSTM(nn.Module):
         #lstm_out, self.hidden = self.lstm(output, self.hidden)
 
 class SentenceBatch():
-
-    def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels):
+    def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels, uposes = None, deprels = None, dheads = None):
         self.sentence_ids = sentence_ids
         self.tokens = tokens
-        self.tok_ids = pad_sequence(tok_ids, batch_first=True)
+        self.tok_ids = pad_sequence(tok_ids, batch_first=True) #bert token ids
         self.tok_types = pad_sequence(tok_types, batch_first=True)
         self.tok_masks = pad_sequence(tok_masks, batch_first=True)
         self.labels = pad_sequence(labels, batch_first=True)
+        #optional rich features, only set when the rich collate function is used
+        self.uposes = pad_sequence(uposes, batch_first=True) if uposes is not None else None
+        self.deprels = pad_sequence(deprels, batch_first=True) if deprels is not None else None
+        self.dheads = pad_sequence(dheads, batch_first=True) if dheads is not None else None
 
     def getBatchEncoding(self):
         dico = { 'input_ids': self.tok_ids, 'token_type_ids': self.tok_types, 'attention_mask': self.tok_masks }
@@ -102,6 +103,7 @@ def toks_to_ids(sentence):
     #print("sentence=", sentence)
     tokens = ['[CLS]'] + list(sentence) + ['[SEP]']
     #print("tokens=", tokens)
+    tokenizer = BertTokenizer.from_pretrained(bert)
     return torch.tensor(tokenizer.convert_tokens_to_ids(tokens)) #len(tokens)
     #print("token_ids=", token_ids)
 
@@ -189,10 +191,10 @@ def train(corpus, fmt):
         print(f"Epoch {epoch} Accuracy {total_acc/l} Loss {total_loss/l} F1 {f1}")
 
     print('done training')
-    output_file = save_model(model, corpus, params)
+    output_file = save_model(model, 'baseline', corpus, params)
     print(f'model saved at {output_file}')
 
-def save_model(model, corpus, params):
+def save_model(model, model_type, corpus, params):
     models_dir = 'saved_models'
 
     if not os.path.isdir(models_dir):
@@ -200,7 +202,7 @@ def save_model(model, corpus, params):
     corpus_dir = os.path.join(models_dir, corpus)
     if not os.path.isdir(corpus_dir):
         os.system(f'mkdir {corpus_dir}')
-    model_file = f"{corpus}_{params['fm']}_{params['bs']}_{params['ne']}_{params['lr']}_{params['rg']}_{params['do']}_{params['bi']}.pth"
+    model_file = f"{model_type}_{corpus}_{params['fm']}_{params['bs']}_{params['ne']}_{params['lr']}_{params['rg']}_{params['do']}_{params['bi']}.pth"
     output_file = os.path.join(corpus_dir, model_file)
     if not os.path.isfile(output_file):
         torch.save(model, output_file)
diff --git a/trytorch/train_model_rich.py b/trytorch/train_model_rich.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccb72b33f2634cc2ae705f55f0be42bce7fa6bda
--- /dev/null
+++ b/trytorch/train_model_rich.py
@@ -0,0 +1,233 @@
+import numpy as np
+import sys
+import os
+from transformers import BertModel
+from transformers.tokenization_utils_base import BatchEncoding
+import torch
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from train_model_baseline import SentenceBatch, add_cls_sep, toks_to_ids, make_labels, make_tok_types, make_tok_masks, save_model
+
+bert = 'bert-base-multilingual-cased'
+#bert_embeddings = BertModel.from_pretrained(bert)
+
+class RichLSTM(nn.Module):
+
+    def __init__(self, batch_size, hidden_size, n_upos, upos_emb_dim, n_deprel, deprel_emb_dim, num_layers=1, bidirectional=False):
+        super().__init__()
+        self.batch_size = batch_size
+        self.hidden_size = hidden_size
+
+        #embedding dims
+        self.bert_emb_dim = 768
+        self.upos_emb_dim = upos_emb_dim
+        self.deprel_emb_dim = deprel_emb_dim
+
+        #embedding layers
+        self.bert_embeddings = BertModel.from_pretrained(bert)
+        self.upos_embeddings = nn.Embedding(n_upos + 1, self.upos_emb_dim, padding_idx=0)
+        self.deprel_embeddings = nn.Embedding(n_deprel + 1, self.deprel_emb_dim, padding_idx=0)
+
+        lstm_input_size = self.bert_emb_dim + self.upos_emb_dim + self.deprel_emb_dim + 1
+        self.lstm = nn.LSTM(lstm_input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
+        d = 2 if bidirectional else 1
+        self.hiddenToLabel = nn.Linear(d * hidden_size, 1)
+        self.act = nn.Sigmoid()
+
+        #convenient values for slicing the vectors: [ bert (768) | upos (U) | deprel (R) | dhead (1) ]
+        self.bert_upos = self.bert_emb_dim + self.upos_emb_dim
+        self.bert_upos_deprel = self.bert_emb_dim + self.upos_emb_dim + self.deprel_emb_dim
+
+    def forward(self, tok_batch, upos_batch, deprel_batch, dhead_batch):
+        # batch: [B x L], where
+        #   B = batch_size
+        #   L = max sentence length in batch
+
+        bert_output = self.bert_embeddings(**tok_batch)
+        bert_output = bert_output.last_hidden_state
+        # bert_output: [B x L x 768]
+        upos_output = self.upos_embeddings(upos_batch)
+        # upos_output: [B x L x U]
+        deprel_output = self.deprel_embeddings(deprel_batch)
+        # deprel_output: [B x L x R]
+        dhead_output = dhead_batch[:,:,None]
+        # dhead_output: [B x L x 1]
+        #print("bert_output=", bert_output.shape)
+        #print("upos_output=", upos_output.shape)
+        #print("deprel_output", deprel_output.shape)
+        #print("dhead_output", dhead_output.shape)
+
+        full_output = torch.cat((bert_output, upos_output, deprel_output, dhead_output), dim=2)
+        output64, (last_hidden_state, last_cell_state) = self.lstm(full_output)
+        output1 = self.hiddenToLabel(output64)
+        #print("output1=", output1.shape)
+        return self.act(output1[:,:,0])
+
+    #lstm_out, self.hidden = self.lstm(output, self.hidden)
+
+#class RichSentenceBatch():
+#
+#    def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, upos, deprels, dheads, labels):
+#        self.sentence_ids = sentence_ids
+#        self.tokens = tokens
+#        self.tok_ids = pad_sequence(tok_ids, batch_first=True)
+#        self.tok_types = pad_sequence(tok_types, batch_first=True)
+#        self.tok_masks = pad_sequence(tok_masks, batch_first=True)
+#        self.upos = pad_sequence(upos, batch_first=True)
+#        self.deprels = pad_sequence(deprels, batch_first=True) #LOG SCALE
+#        self.dheads = pad_sequence(dheads, batch_first=True)
+#        self.labels = pad_sequence(labels, batch_first=True)
+#
+#    def getBatchEncoding(self):
+#        dico = { 'input_ids': self.tok_ids, 'token_type_ids': self.tok_types, 'attention_mask': self.tok_masks }
+#        return BatchEncoding(dico)
+
+def cat_to_id(instance, cat):
+    return np.where(cat == instance)[0][0] + 1
+
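+#dh_list encodes the syntactic head of each token as a single scalar feature:
+#0 for the root, otherwise the log of the absolute distance between the token id and its head id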
+def dh_list(tok_ids, idheads, s, e):
+    l = e - s
+    ret = [0] * l
+    new_tok_ids = tok_ids[s:e]
+    new_idheads = idheads[s:e]
+    for i, idhead in enumerate(new_idheads):
+        if idhead == 0:
+            ret[i] = 0
+        else:
+            ret[i] = np.log(abs(idhead - new_tok_ids[i]))
+    return ret
+
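+#builds one training example per sentence as a tuple
+#(tokens, upos ids, deprel ids, head distances, labels); sentences longer than 510 tokens are split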
+def generate_rich_sentence_list(corpus, sset, fmt):
+    parsed_data = os.path.join("parsed_data", f"parsedrich_{corpus}_{sset}.{fmt}.npz")
+    #print("PATH", parsed_data)
+    if not os.path.isfile(parsed_data):
+        print("you must parse the corpus before training on it")
+        sys.exit()
+    data = np.load(parsed_data, allow_pickle = True)
+    tok_ids, toks, labels, upos, idheads, deprels = [data[f] for f in data.files]
+
+    unique_upos = np.unique(upos)
+    unique_deprels = np.unique(deprels)
+    nb_upos = len(unique_upos)
+    nb_deprels = len(unique_deprels)
+
+    sentences = []
+
+    previous_i = 0
+    for index, tok_id in enumerate(tok_ids):
+        if index > 0 and (tok_id == 1 or tok_id == '1' or (isinstance(tok_id, str) and tok_id.startswith('1-'))):
+            if index - previous_i <= 510:
+                sentences.append((toks[previous_i:index], [cat_to_id(u, unique_upos) for u in upos[previous_i:index]], [cat_to_id(d, unique_deprels) for d in deprels[previous_i:index]], dh_list(tok_ids, idheads, previous_i, index), labels[previous_i:index]))
+            else:
+                sep = previous_i + 510
+                sentences.append((toks[previous_i:sep], [cat_to_id(u, unique_upos) for u in upos[previous_i:sep]], [cat_to_id(d, unique_deprels) for d in deprels[previous_i:sep]], dh_list(tok_ids, idheads, previous_i, sep), labels[previous_i:sep]))
+                if index - sep > 510:
+                    print("still too long sentence...")
+                    sys.exit()
+                sentences.append((toks[sep:index], [cat_to_id(u, unique_upos) for u in upos[sep:index]], [cat_to_id(d, unique_deprels) for d in deprels[sep:index]], dh_list(tok_ids, idheads, sep, index), labels[sep:index]))
+            previous_i = index
+
+    indexed_sentences = [0] * len(sentences)
+    for i, sentence in enumerate(sentences):
+        indexed_sentences[i] = (i, sentence)
+
+    return indexed_sentences, nb_upos, nb_deprels
+
+def int_add_zeros(sentence):
+    zero = torch.zeros(1).int()
+    return torch.cat((zero, torch.tensor(sentence).int(), zero))
+
+def add_zeros(sentence):
+    zero = torch.zeros(1)
+    return torch.cat((zero, torch.tensor(sentence), zero))
+
+def collate_rich_batch(batch):
+    #each element of batch is (sentence_id, (toks, upos ids, deprel ids, head distances, labels))
+    sentence_ids = [i for i, _ in batch]
+    token_batch = [s[0] for _, s in batch]
+    upos_batch = [s[1] for _, s in batch]
+    deprel_batch = [s[2] for _, s in batch]
+    idhead_batch = [s[3] for _, s in batch]
+    label_batch = [s[4] for _, s in batch]
+    labels = [make_labels(sentence) for sentence in label_batch]
+    tokens = [add_cls_sep(sentence) for sentence in token_batch]
+    tok_ids = [toks_to_ids(sentence) for sentence in token_batch]
+    lengths = [len(toks) for toks in tok_ids]
+    tok_types = [make_tok_types(l) for l in lengths]
+    tok_masks = [make_tok_masks(l) for l in lengths]
+    uposes = [int_add_zeros(sentence) for sentence in upos_batch]
+    deprels = [int_add_zeros(sentence) for sentence in deprel_batch]
+    dheads = [add_zeros(sentence) for sentence in idhead_batch]
+
+    return SentenceBatch(sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels, uposes = uposes, deprels = deprels, dheads = dheads)
+
+def train(corpus):
+    print(f'starting rich training of {corpus}...')
+    data, nb_upos, nb_deprels = generate_rich_sentence_list(corpus, 'train', 'conllu')
+    upos_emb_dim = int(np.sqrt(nb_upos))
+    deprel_emb_dim = int(np.sqrt(nb_deprels))
+
+    torch.manual_seed(1)
+
+    #PARAMETERS
+    batch_size = 32
+    num_epochs = 5
+    lr = 0.0001
+    reg = 0.001
+    dropout = 0.01
+    bidirectional = True
+    params = { 'fm' : 'conllu', 'bs': batch_size, 'ne': num_epochs, 'lr': lr, 'rg': reg, 'do': dropout, 'bi': bidirectional }
+
+    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_rich_batch)
+    model = RichLSTM(batch_size, 64, nb_upos, upos_emb_dim, nb_deprels, deprel_emb_dim, num_layers=1, bidirectional=bidirectional)
+    loss_fn = nn.BCELoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)
+    model.train()
+
+    l = len(dataloader.dataset)
+
+    for epoch in range(num_epochs):
+        total_acc = 0
+        total_loss = 0
+        tp = 0
+        fp = 0
+        fn = 0
+
+        for sentence_batch in tqdm(dataloader):
+            optimizer.zero_grad()
+            label_batch = sentence_batch.labels
+            #print("label_batch", label_batch.shape)
+            #print("tok_ids", sentence_batch.tok_ids.shape)
+            pred = model(sentence_batch.getBatchEncoding(), sentence_batch.uposes, sentence_batch.deprels, sentence_batch.dheads)
+            #print("pred", pred.shape)
+            loss = loss_fn(pred, label_batch)
+            loss.backward()
+            optimizer.step()
+            pred_binary = (pred >= 0.5).float()
+
+            for i in range(label_batch.size(0)):
+                for j in range(label_batch.size(1)):
+                    if pred_binary[i,j] == 1.:
+                        if label_batch[i,j] == 1.:
+                            tp += 1
+                        else: fp += 1
+                    elif label_batch[i,j] == 1.: fn += 1
+            #print("tp,fp,fn=",tp/label_batch.size(1),fp/label_batch.size(1),fn/label_batch.size(1))
+            #nb_1 = pred_binary.sum()
+            #print("nb predicted 1=", nb_1)
+            sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
+            assert (sum_score <= batch_size)
+            total_acc += sum_score
+            total_loss += loss.item() #*label_batch.size(0)
+
+        f1 = tp / (tp + (fp + fn) / 2)
+        print(f"Epoch {epoch} Accuracy {total_acc/l} Loss {total_loss/l} F1 {f1}")
+
+    print('done training')
+    output_file = save_model(model, 'rich', corpus, params)
+    print(f'model saved at {output_file}')
+
+def main():
+    if len(sys.argv) != 2:
+        print("usage: train_model_rich.py <corpus>")
+        sys.exit()
+    corpus = sys.argv[1]
+    train(corpus)
+
+if __name__ == '__main__':
+    main()