diff --git a/trytorch/check_parse.py b/trytorch/check_parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecb1fb3d2cb58657026c4655e38a23ed36278632
--- /dev/null
+++ b/trytorch/check_parse.py
@@ -0,0 +1,16 @@
+import sys
+import numpy as np
+
+path = sys.argv[1]
+data = np.load(path, allow_pickle=True)
+tok_ids, toks, labels = [data[f] for f in data.files]
+
+for i, tok in enumerate(toks):
+    if isinstance(tok, str):
+        if len(tok) > 35:
+            print(tok)
+    else:
+        print(tok_ids[i])
+        print(type(tok))
+
+
diff --git a/trytorch/check_rich.py b/trytorch/check_rich.py
new file mode 100644
index 0000000000000000000000000000000000000000..6233fe54f6418a7afc24cfa88c9409a32c576de3
--- /dev/null
+++ b/trytorch/check_rich.py
@@ -0,0 +1,34 @@
+import sys
+import numpy as np
+
+path = sys.argv[1]
+data = np.load(path, allow_pickle=True)
+tok_ids, toks, labels, upos, idheads, deprels = [data[f] for f in data.files]
+
+print(len(toks))
+print(len(upos))
+
+labs_upos = np.unique(upos)
+labs_idheads = np.unique(idheads)
+labs_deprel = np.unique(deprels)
+print(labs_upos)
+print(labs_idheads)
+print(labs_deprel)
+
+for i, tok in enumerate(toks):
+    if isinstance(tok, str):
+        if len(tok) > 35:
+            print(tok)
+    else:
+        print(tok_ids[i])
+        print(type(tok))
+
+for i, tok in enumerate(upos):
+    if isinstance(tok, str):
+        if len(tok) > 35:
+            print(tok)
+    else:
+        print(tok_ids[i])
+        print(type(tok))
+
+
diff --git a/trytorch/parse_corpus.py b/trytorch/parse_corpus.py
index 59f8e91168d4230e81628c8720a08094d416387c..f24fc233d161fda869c94bbdec93f1734f8924df 100644
--- a/trytorch/parse_corpus.py
+++ b/trytorch/parse_corpus.py
@@ -6,16 +6,22 @@ import mmap
 import re
 import numpy as np
 
-def parse(corpus):
+def parse(corpus, rich):
         """Turn corpus into a list of sentences then save it"""
         print(f'parsing of {corpus} begins')
         split = corpus.split("_")
         corpus_dir = split[0]
         input_file = os.path.join("data/", corpus_dir, corpus)
         output_dir = 'parsed_data'
-        output_file = os.path.join(output_dir, f'parsed_{corpus}')
+        if rich:
+            if not corpus.endswith('conllu'):
+                print('rich parsing only possible with conllu file')
+                sys.exit()
+            output_file = os.path.join(output_dir, f'parsedrich_{corpus}')
+        else:
+            output_file = os.path.join(output_dir, f'parsed_{corpus}')
         if not os.path.isdir(output_dir):
-                os.system(f"mkdir {output_dir}")
+                os.mkdir(output_dir)
         if os.path.isfile(output_file + '.npz'):
                 print(f'{corpus} already parsed. do you wish to overwrite? (Y/n)')
                 user = input()
@@ -23,11 +29,11 @@ def parse(corpus):
                     print('done')
                     sys.exit()
 
-        column_names = ['tok_id','tok','1','2','3','4','5','6','7','label']
-        if corpus == 'eng.rst.rstdt_train.conllu':
-            df = pd.read_csv(input_file, names = column_names, skip_blank_lines=True, comment="#", sep="\t", engine='python', error_bad_lines=False, header=None)
-        else:
-            df = pd.read_csv(input_file, names = column_names, skip_blank_lines=True, comment="#", sep="\t")
+        column_names = ['tok_id','tok','lemma','upos','xpos','gram','idhead','deprel','type','label']
+        #if corpus.startswith('eng.rst.rstdt_train') or corpus.startswith('por.rst.cstn'):
+        #    df = pd.read_csv(input_file, names = column_names, skip_blank_lines=True, sep="\t", quoting=3, comment='#')
+        #else:
+        df = pd.read_csv(input_file, names = column_names, skip_blank_lines=True, comment="#", sep="\t", quoting=3)
         tok_ids = df['tok_id'].values
         toks = df['tok'].values
         labels = df['label'].values
@@ -38,21 +44,35 @@ def parse(corpus):
                     new_labels[i] = 1
                 else: new_labels[i] = 0
         else:
-            new_labels = np.where((labels == ("BeginSeg=Yes")), 1, 0) #labels == BeginSeg=Yes
+            new_labels = np.where((labels == "BeginSeg=Yes"), 1, 0) 
         labels = new_labels
         nb_segs = np.sum(labels)
-        
-        np.savez_compressed(output_file, tok_ids = tok_ids, toks = toks, labels = labels)
+       
+        if rich:
+            upos = df['upos'].values
+            idheads = df['idhead'].values
+            deprels = df['deprel'].values
+            np.savez_compressed(output_file, tok_ids = tok_ids, toks = toks, labels = labels, upos = upos, idheads = idheads, deprels = deprels)
+
+        else:
+            np.savez_compressed(output_file, tok_ids = tok_ids, toks = toks, labels = labels)
         
         print(f'done parsing. data saved at {output_file}')
 
 def main():
-        if len(sys.argv) < 2:
-                print("usage: parse_corpus.py <corpus>")
+        if len(sys.argv) < 2 or len(sys.argv) > 3:
+            print("usage: parse_corpus.py <corpus> [<rich>]")
+            sys.exit()
+        corpus = sys.argv[1]
+        rich = False
+        #rich = True means we collect not only token and label but also dependency relationship for conllu files
+        if len(sys.argv) > 2:
+            if sys.argv[2] == 'rich':
+                rich = True
+            else:
+                print("usage: parse_corpus.py <corpus> [<rich>]")
                 sys.exit()
-        corpora = sys.argv[1:]
-        for corpus in corpora:
-                parse(corpus)
+        parse(corpus, rich)
 
 if __name__ == '__main__':
         main()
diff --git a/trytorch/parsed_data/parsed_fra.sdrt.annodis_train-mini.conllu.npz b/trytorch/parsed_data/parsed_fra.sdrt.annodis_train-mini.conllu.npz
deleted file mode 100644
index 6821ad66bd027b64da05151676c2a8d414e9875f..0000000000000000000000000000000000000000
Binary files a/trytorch/parsed_data/parsed_fra.sdrt.annodis_train-mini.conllu.npz and /dev/null differ
diff --git a/trytorch/test_model.py b/trytorch/test_model.py
index 26bc6ff327183bf136dc864533a4959326157944..df60cfcbe008aa6ca28bf679a80a44e1a211b73c 100644
--- a/trytorch/test_model.py
+++ b/trytorch/test_model.py
@@ -1,27 +1,28 @@
 import numpy as np
 import sys
 import os
-from transformers import BertTokenizer, BertModel
-from transformers.tokenization_utils_base import BatchEncoding
+import argparse
 import torch
 from torch import nn
-from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
-#from torchmetrics import F1Score
-#from torch.autograd import Variable
 from tqdm import tqdm
-from train_model_baseline import LSTM, SentenceBatch, generate_sentence_list, toks_to_ids, make_labels, make_tok_types, make_tok_masks, collate_batch
+from train_model_baseline import LSTM, SentenceBatch, generate_sentence_list, collate_batch
+from train_model_rich import RichLSTM, generate_rich_sentence_list, collate_rich_batch
 
 bert = 'bert-base-multilingual-cased'
-tokenizer = BertTokenizer.from_pretrained(bert)
 
-def test(corpus, model_path, test_set, fmt):
+def test(model_path, model_type, corpus, test_set, fmt, show_errors):
         model = torch.load(model_path)
-        print(f'Model:\t{model_path}\nEval:\t{corpus}_{test_set}\nFormat:\t{fmt}')
-        data = generate_sentence_list(corpus, test_set, fmt)
+        print(f'Model:\t{model_path}\nType:\t{model_type}\nEval:\t{corpus}_{test_set}\nFormat:\t{fmt}')
+
         batch_size = 32
-        dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
-        
+        if model_type == 'baseline':
+            data = generate_sentence_list(corpus, test_set, fmt)
+            dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
+        else:
+            data, _, _ = generate_rich_sentence_list(corpus, test_set, fmt)
+            dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_rich_batch)
+
         model.eval()
         loss_fn = nn.BCELoss()
         
@@ -29,27 +30,25 @@ def test(corpus, model_path, test_set, fmt):
         errors = []
         
         with torch.no_grad():
-                total_acc = 0
-                total_loss = 0
-                tp = 0
-                fp = 0 
-                fn = 0
+                total_acc, total_loss = 0, 0
+                tp, fp, fn = 0, 0, 0
         
                 for sentence_batch in tqdm(dataloader):
                         label_batch = sentence_batch.labels
-                        #print("label_batch", label_batch.shape)
-                        pred = model(sentence_batch.getBatchEncoding())
-                        #print("pred", pred.shape)
+                        if model_type == 'baseline':
+                            pred = model(sentence_batch.getBatchEncoding())
+                        else:
+                            pred = model(sentence_batch.getBatchEncoding(), sentence_batch.upos, sentence_batch.deprels, sentence_batch.dheads)
                         loss = loss_fn(pred, label_batch)
                         pred_binary = (pred >= 0.5).float()
                         
                         for i in range(label_batch.size(0)):
                                 for j in range(label_batch.size(1)):
-                                                if pred_binary[i,j] == 1.:
-                                                        if label_batch[i,j] == 1.:
-                                                                tp += 1
-                                                        else: fp += 1
-                                                elif label_batch[i,j] == 1.: fn += 1
+                                        if pred_binary[i,j] == 1.:
+                                                if label_batch[i,j] == 1.:
+                                                        tp += 1
+                                                else: fp += 1
+                                        elif label_batch[i,j] == 1.: fn += 1
                         sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
                         assert (sum_score <= batch_size)
                         for i, sentence_id in enumerate(sentence_batch.sentence_ids):
@@ -61,8 +60,9 @@ def test(corpus, model_path, test_set, fmt):
                 precision = tp / (tp + fp)
                 recall = tp / (tp + fn)
                 f1 = 2 * (precision * recall) / (precision + recall)
-
-        #print_errors(errors, data)
+    
+        if show_errors > 0:
+            print_errors(errors, data, max_print=show_errors)
 
         print(f"Acc\t{total_acc/l}\nLoss\t{total_loss/l}\nP\t{precision}\nR\t{recall}\nF1\t{f1}\n\n")
 
@@ -72,10 +72,9 @@ def print_segmentation(toks, labels):
                 if i+1 < len(labels) and labels[i+1] == 1:
                         s += "| "
                 s += str(tok) + " "
-        print(s)
+        print(s + '\n')
 
 def print_errors(errors, data, max_print=25):
-        
         print(f'Reporting {max_print} errors')
         max_print = min(max_print, len(errors))
         for sentence_id, pred in errors[:max_print]:
@@ -88,32 +87,23 @@ def print_errors(errors, data, max_print=25):
                 print_segmentation(toks, labels)
 
 def main():
-        if len(sys.argv) < 2 or len(sys.argv) > 4:
-                print("usage: test_model.py <model> [<test_corpus>] [<test/dev/train>]")
-                sys.exit()
-        model_path = sys.argv[1]
-        if not os.path.isfile(model_path):
-                print("model not found. please train the model first. please provide a relative path from discut dir.")
-                sys.exit()
-        test_set = 'test'
-        if len(sys.argv) == 4:
-                if sys.argv[3] == 'dev':
-                        test_set = 'dev'
-                elif sys.arvg[3] == 'train':
-                        test_set = 'train'
-                elif sys.argv[3] != 'test':        
-                        print("usage: test_model.py <model> [<test_corpus>] [<test/dev/train>]")
-                        sys.exit()
-        model_split = model_path.split("/")
-        corpus = model_split[1]        
-        if len(sys.argv) >= 3:
-                corpus = sys.argv[2] 
-        fmt = 'conllu'
-        params = model_split[2]
-        params_split = params.split("_")
-        if params_split[1] == 'tok':
-                fmt = 'tok'
-        test(corpus, model_path, test_set, fmt)        
+        parser = argparse.ArgumentParser(description='Test a model on a dataset')
+        parser.add_argument('--model', help='Path to .pth saved model')
+        parser.add_argument('--format', default='conllu', help='tok or conllu')
+        parser.add_argument('--type', default='baseline', help="baseline or rich model")
+        parser.add_argument('--corpus', help='corpus to test on')
+        parser.add_argument('--set', default='test', help='portion of the corpus to test on')
+        parser.add_argument('--errors', default=0, help='number of prediction errors to display on standard output')
+
+        params = parser.parse_args()
+
+        if not os.path.isfile(params.model):
+            print("model not found. please train the model first. please provide a relative path from discut dir.")
+            sys.exit()
+        if params.type == 'rich' and params.format == 'tok':
+            print('a rich model requires a .conllu file')
+            sys.exit()
+        test(params.model, params.type, params.corpus, params.set, params.format, params.errors)
 
 if __name__ == '__main__':
         main()
diff --git a/trytorch/train_model_baseline.py b/trytorch/train_model_baseline.py
index 4acbd84cf3809e3c764d78118548bfeb392c53dd..b6038b51e56698c134c87b1aa2a6fddaccad3066 100644
--- a/trytorch/train_model_baseline.py
+++ b/trytorch/train_model_baseline.py
@@ -12,7 +12,6 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 bert = 'bert-base-multilingual-cased'
-tokenizer = BertTokenizer.from_pretrained(bert)
 #bert_embeddings = BertModel.from_pretrained(bert)
 
 class LSTM(nn.Module):
@@ -42,14 +41,16 @@ class LSTM(nn.Module):
                 #lstm_out, self.hidden = self.lstm(output, self.hidden)
 
 class SentenceBatch():
-
-        def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels):
+        def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels, uposes = None, deprels = None, dheads = None):
                 self.sentence_ids = sentence_ids
                 self.tokens = tokens
-                self.tok_ids = pad_sequence(tok_ids, batch_first=True)
+                self.tok_ids = pad_sequence(tok_ids, batch_first=True) #bert token ids
                 self.tok_types = pad_sequence(tok_types, batch_first=True)
                 self.tok_masks = pad_sequence(tok_masks, batch_first=True)
                 self.labels = pad_sequence(labels, batch_first=True)
+                self.uposes = pad_sequence(upos, batch_first=True)
+                self.dheads = pad_sequence(dheads, batch_first=True)
+                self.labels = pad_sequence(labels, batch_first=True)
 
         def getBatchEncoding(self):
                 dico = { 'input_ids': self.tok_ids, 'token_type_ids': self.tok_types, 'attention_mask': self.tok_masks }
@@ -102,6 +103,7 @@ def toks_to_ids(sentence):
         #print("sentence=", sentence)
         tokens = ['[CLS]'] + list(sentence) + ['[SEP]']
         #print("tokens=", tokens)
+        tokenizer = BertTokenizer.from_pretrained(bert)
         return torch.tensor(tokenizer.convert_tokens_to_ids(tokens)) #len(tokens)
         #print("token_ids=", token_ids)
 
@@ -189,10 +191,10 @@ def train(corpus, fmt):
                 print(f"Epoch {epoch} Accuracy {total_acc/l} Loss {total_loss/l} F1 {f1}")
         
         print('done training')
-        output_file = save_model(model, corpus, params)
+        output_file = save_model(model, 'baseline', corpus, params)
         print(f'model saved at {output_file}')
 
-def save_model(model, corpus, params):
+def save_model(model, model_type, corpus, params):
 
         models_dir = 'saved_models'
         if not os.path.isdir(models_dir):
@@ -200,7 +202,7 @@ def save_model(model, corpus, params):
         corpus_dir = os.path.join(models_dir, corpus)        
         if not os.path.isdir(corpus_dir):
                 os.system(f'mkdir {corpus_dir}')
-        model_file = f"{corpus}_{params['fm']}_{params['bs']}_{params['ne']}_{params['lr']}_{params['rg']}_{params['do']}_{params['bi']}.pth"
+        model_file = f"{model_type}_{corpus}_{params['fm']}_{params['bs']}_{params['ne']}_{params['lr']}_{params['rg']}_{params['do']}_{params['bi']}.pth"
         output_file = os.path.join(corpus_dir, model_file)
         if not os.path.isfile(output_file):
                 torch.save(model, output_file)
diff --git a/trytorch/train_model_rich.py b/trytorch/train_model_rich.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccb72b33f2634cc2ae705f55f0be42bce7fa6bda
--- /dev/null
+++ b/trytorch/train_model_rich.py
@@ -0,0 +1,233 @@
+import numpy as np
+import sys
+import os
+from transformers import BertModel
+from transformers.tokenization_utils_base import BatchEncoding
+import torch
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from train_model_baseline import SentenceBatch, add_cls_sep, toks_to_ids, make_labels, make_tok_types, make_tok_masks, save_model
+
+bert = 'bert-base-multilingual-cased'
+#bert_embeddings = BertModel.from_pretrained(bert)
+
+class RichLSTM(nn.Module):
+
+        def __init__(self, batch_size, hidden_size, n_upos, upos_emb_dim, n_deprel, deprel_emb_dim, num_layers=1, bidirectional=False):
+                super().__init__()
+                self.batch_size = batch_size
+                self.hidden_size = hidden_size
+
+                #embedding dims 
+                self.bert_emb_dim = 768
+                self.upos_emb_dim = upos_emb_dim
+                self.deprel_emb_dim = deprel_emb_dim
+
+                #embedding layers
+                self.bert_embeddings = BertModel.from_pretrained(bert)
+                self.upos_embeddings = nn.Embedding(n_upos + 1, self.upos_emb_dim, padding_idx=0)
+                self.deprel_embeddings = nn.Embedding(n_deprel + 1, self.deprel_emb_dim, padding_idx=0)
+                lstm_input_size = self.bert_emb_dim + self.upos_emb_dim + self.deprel_emb_dim + 1
+                self.lstm = nn.LSTM(lstm_input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
+                d = 2 if bidirectional else 1
+                self.hiddenToLabel = nn.Linear(d * hidden_size, 1)
+                self.act = nn.Sigmoid()
+
+                #convenient values for slicing the vectors: [ bert (768) | upos (U) | deprel (R) | dhead (1) ]
+                self.bert_upos = self.bert_emb_dim + self.upos_emb_dim
+                self.bert_upos_deprel = self.bert_emb_dim + self.upos_emb_dim + self.deprel_emb_dim
+
+        def forward(self, tok_batch, upos_batch, deprel_batch, dhead_batch):
+                # batch: [B x L], where
+                # B = batch_size
+                # L = max sentence length in batch
+                
+                bert_output = self.bert_embeddings(**tok_batch)
+                bert_output = bert_output.last_hidden_state
+                # bert_output: [B x L x 768]
+                upos_output = self.upos_embeddings(upos_batch)
+                # upos_output: [B x L x U]
+                deprel_output = self.deprel_embeddings(deprel_batch)
+                # deprel_output: [B x L x R]
+                dhead_output = dhead_batch[:,:,None]
+                # dhead_output: [B x L x 1]
+                #print("bert_output=", bert_output.shape)
+                #print("upos_output=", upos_output.shape)
+                #print("deprel_output", deprel_output.shape)
+                #print("dhead_output", dhead_output.shape)
+
+                full_output = torch.cat((bert_output, upos_output, deprel_output, dhead_output), dim=2)
+                output64, (last_hidden_state, last_cell_state) = self.lstm(full_output)
+                output1 = self.hiddenToLabel(output64)
+                #print("output1=", output1.shape)
+                return self.act(output1[:,:,0])
+
+                #lstm_out, self.hidden = self.lstm(output, self.hidden)
+
+#class RichSentenceBatch():
+#
+#        def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, upos, deprels, dheads, labels):
+#                self.sentence_ids = sentence_ids
+#                self.tokens = tokens
+#                self.tok_ids = pad_sequence(tok_ids, batch_first=True)
+#                self.tok_types = pad_sequence(tok_types, batch_first=True)
+#                self.tok_masks = pad_sequence(tok_masks, batch_first=True)
+#                self.upos = pad_sequence(upos, batch_first=True)
+#                self.deprels = pad_sequence(deprels, batch_first=True) #LOG SCALE
+#                self.dheads = pad_sequence(dheads, batch_first=True)
+#                self.labels = pad_sequence(labels, batch_first=True)
+#
+#        def getBatchEncoding(self):
+#                dico = { 'input_ids': self.tok_ids, 'token_type_ids': self.tok_types, 'attention_mask': self.tok_masks }
+#                return BatchEncoding(dico)
+        
+def cat_to_id(instance, cat):
+        return np.where(cat == instance)[0][0] + 1
+
+def dh_list(tok_ids, idheads, s, e):
+        l = e - s
+        ret = [0] * l 
+        new_tok_ids = tok_ids[s:e]
+        new_idheads = idheads[s:e]
+        for i, idhead in enumerate(new_idheads):
+            if idhead == 0:
+                ret[i] = 0
+            else:
+                ret[i] = np.log(abs(idhead - new_tok_ids[i]))
+        return ret
+
+def generate_rich_sentence_list(corpus, sset, fmt):
+        parsed_data = os.path.join("parsed_data", f"parsedrich_{corpus}_{sset}.{fmt}.npz")
+        #print("PATH", parsed_data)
+        if not os.path.isfile(parsed_data):
+                print("you must parse the corpus before training it")
+                sys.exit()
+        data = np.load(parsed_data, allow_pickle = True)
+        tok_ids, toks, labels, upos, idheads, deprels = [data[f] for f in data.files]
+        
+        unique_upos = np.unique(upos)
+        unique_deprels = np.unique(deprels)
+        nb_upos = len(unique_upos)
+        nb_deprels = len(unique_deprels)
+
+        sentences = []
+
+        previous_i = 0
+        for index, tok_id in enumerate(tok_ids):
+                if index > 0 and (tok_id == 1 or tok_id == '1' or (isinstance(tok_id, str) and tok_id.startswith('1-'))):
+                        if index - previous_i <= 510:
+                                sentences.append((toks[previous_i:index], [cat_to_id(u, unique_upos) for u in upos[previous_i:index]], [cat_to_id(d, unique_deprels) for d in deprels[previous_i:index]], dh_list(tok_ids, idheads, previous_i, index), labels[previous_i:index]))
+                                previous_i = index
+                        else: 
+                                sep = previous_i + 510
+                                sentences.append((toks[previous_i:sep], labels[previous_i:sep]))
+                                if sep - previous_i > 510:
+                                        print("still too long sentence...")
+                                        sys.exit()
+                                sentences.append((toks[sep:index], [cat_to_id(u, unique_upos) for u in upos[sep:index]], [cat_to_id(d, unique_deprels) for d in deprels[sep:index]], dh_list(tok_ids, idheads, sep, index), labels[sep:index]))
+
+        indexed_sentences = [0] * len(sentences)
+        for i, sentence in enumerate(sentences):
+                indexed_sentences[i] = (i, sentence)
+
+        return indexed_sentences, nb_upos, nb_deprels
+
+def int_add_zeros(sentence):
+        zero = torch.zeros(1).int()
+        return torch.cat((zero, torch.tensor(sentence).int(), zero))
+
+def add_zeros(sentence):
+        zero = torch.zeros(1)
+        return torch.cat((zero, torch.tensor(sentence), zero))
+
+def collate_rich_batch(batch):
+        sentence_ids, token_batch, upos_batch, idhead_batch, deprel_batch, label_batch = [i for i, (_, _, _, _, _) in batch], [j for _, (j, _, _, _, _) in batch], [k for _, (_, k, _, _, _) in batch], [l for _, (_, _, l, _, _) in batch], [m for _, (_, _, _, m, _) in batch], [n for _, (_, _, _, _, n) in batch]
+        labels = [make_labels(sentence) for sentence in label_batch]
+        tokens = [add_cls_sep(sentence) for sentence in token_batch]
+        tok_ids = [toks_to_ids(sentence) for sentence in token_batch]
+        lengths = [len(toks) for toks in tok_ids]
+        tok_types = [make_tok_types(l) for l in lengths]
+        tok_masks = [make_tok_masks(l) for l in lengths]
+        uposes = [int_add_zeros(sentence) for sentence in upos_batch]
+        deprels = [int_add_zeros(sentence) for sentence in deprel_batch]
+        dheads = [add_zeros(sentence) for sentence in idhead_batch]
+        
+        return SentenceBatch(sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels, uposes = uposes, deprels = deprels, dheads = dheads)
+        
+def train(corpus):
+        print(f'starting rich training of {corpus}...')
+        data, nb_upos, nb_deprels = generate_rich_sentence_list(corpus, 'train', 'conllu') 
+        upos_emb_dim = int(np.sqrt(nb_upos))
+        deprel_emb_dim = int(np.sqrt(nb_deprels))
+
+        torch.manual_seed(1)
+
+        #PARAMETERS
+        batch_size = 32 
+        num_epochs = 5
+        lr = 0.0001
+        reg = 0.001
+        dropout = 0.01
+        bidirectional = True
+        params = { 'fm' : 'conllu', 'bs': batch_size, 'ne': num_epochs, 'lr': lr, 'rg': reg, 'do': dropout, 'bi': bidirectional }
+
+        dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_rich_batch)
+        model = RichLSTM(batch_size, 64, nb_upos, upos_emb_dim, nb_deprels, deprel_emb_dim, num_layers=1, bidirectional=bidirectional) 
+        loss_fn = nn.BCELoss() 
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)
+        model.train()
+        
+        l = len(dataloader.dataset)
+        
+        for epoch in range(num_epochs):
+                total_acc = 0
+                total_loss = 0
+                tp = 0
+                fp = 0 
+                fn = 0
+        
+                for sentence_batch in tqdm(dataloader):
+                        optimizer.zero_grad()
+                        label_batch = sentence_batch.labels
+                        #print("label_batch", label_batch.shape)
+                        #print("tok_ids", sentence_batch.tok_ids.shape)
+                        pred = model(sentence_batch.getBatchEncoding(), sentence_batch.uposes, sentence_batch.deprels, sentence_batch.dheads)
+                        #print("pred", pred.shape)
+                        loss = loss_fn(pred, label_batch)
+                        loss.backward()
+                        optimizer.step()
+                        pred_binary = (pred >= 0.5).float()
+                        
+                        for i in range(label_batch.size(0)):
+                                for j in range(label_batch.size(1)):
+                                        if pred_binary[i,j] == 1.:
+                                                if label_batch[i,j] == 1.:
+                                                        tp += 1
+                                                else: fp += 1
+                                        elif label_batch[i,j] == 1.: fn += 1
+                        #print("tp,fp,fn=",tp/label_batch.size(1),fp/label_batch.size(1),fn/label_batch.size(1))
+                        #nb_1 = pred_binary.sum()
+                        #print("nb predicted 1=", nb_1)
+                        sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
+                        assert (sum_score <= batch_size)
+                        total_acc += sum_score
+                        total_loss += loss.item() #*label_batch.size(0)
+
+                f1 = tp / (tp + (fp + fn) / 2)
+                print(f"Epoch {epoch} Accuracy {total_acc/l} Loss {total_loss/l} F1 {f1}")
+        
+        print('done training')
+        output_file = save_model(model, 'rich', corpus, params)
+        print(f'model saved at {output_file}')
+
+def main():
+        if len(sys.argv) != 2:
+                print("usage: train_model_rich.py <corpus>")
+                sys.exit()
+        corpus = sys.argv[1]
+        train(corpus) 
+
+if __name__ == '__main__':
+        main()