diff --git a/trytorch/train_model_baseline.py b/trytorch/train_model_baseline.py
index c772b818c5f0b950ad6c729a7d919224e407356a..7b9dfecc7b5c785a529975c2d6d9ed60baa45256 100644
--- a/trytorch/train_model_baseline.py
+++ b/trytorch/train_model_baseline.py
@@ -1,237 +1,230 @@
 import numpy as np
 import sys
 import os
+import argparse
 from transformers import BertTokenizer, BertModel
 from transformers.tokenization_utils_base import BatchEncoding
 import torch
 from torch import nn
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
-#from torchmetrics import F1Score
-#from torch.autograd import Variable
 from tqdm import tqdm
 
 bert = 'bert-base-multilingual-cased'
-#bert_embeddings = BertModel.from_pretrained(bert)
+tokenizer = BertTokenizer.from_pretrained(bert)
 
 class LSTM(nn.Module):
 
-        def __init__(self, batch_size, input_size, hidden_size, num_layers=1, bidirectional=False):
-                super().__init__()
-                self.batch_size = batch_size
-                self.hidden_size = hidden_size
-                self.bert_embeddings = BertModel.from_pretrained(bert)
-                self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
-                d = 2 if bidirectional else 1
-                self.hiddenToLabel = nn.Linear(d * hidden_size, 1)
-                self.act = nn.Sigmoid()
-
-        def forward(self, batch):
-                output = self.bert_embeddings(**batch)
-                output768 = output.last_hidden_state
-                output64, (last_hidden_state, last_cell_state) = self.lstm(output768)
-                #output64, self.hidden = self.lstm(output768, self.hidden)
-                #print("output64=", output64.shape)
-                #print("last_hidden_state", last_hidden_state.shape)
-                #print("last_cell_state", last_cell_state.shape)
-                output1 = self.hiddenToLabel(output64)
-                #print("output1=", output1.shape)
-                return self.act(output1[:,:,0])
-
-                #lstm_out, self.hidden = self.lstm(output, self.hidden)
+    def __init__(self, batch_size, input_size, hidden_size, n_layers=1, bidirectional=False):
+        super().__init__()
+        self.batch_size = batch_size
+        self.hidden_size = hidden_size
+        self.bert_embeddings = BertModel.from_pretrained(bert)
+        self.lstm = nn.LSTM(input_size, hidden_size, n_layers, batch_first=True, bidirectional=bidirectional)
+        d = 2 if bidirectional else 1
+        self.hiddenToLabel = nn.Linear(d * hidden_size, 1)
+        self.act = nn.Sigmoid()
+
+    def forward(self, batch):
+        output = self.bert_embeddings(**batch)
+        output768 = output.last_hidden_state
+        output64, (last_hidden_state, last_cell_state) = self.lstm(output768)
+        #output64, self.hidden = self.lstm(output768, self.hidden)
+        #print("output64=", output64.shape)
+        #print("last_hidden_state", last_hidden_state.shape)
+        #print("last_cell_state", last_cell_state.shape)
+        output1 = self.hiddenToLabel(output64)
+        #print("output1=", output1.shape)
+        return self.act(output1[:,:,0])
+
+        #lstm_out, self.hidden = self.lstm(output, self.hidden)
 
 class SentenceBatch():
-        def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels, uposes = None, deprels = None, dheads = None):
-                self.sentence_ids = sentence_ids
-                self.tokens = tokens
-                self.tok_ids = pad_sequence(tok_ids, batch_first=True) #bert token ids
-                self.tok_types = pad_sequence(tok_types, batch_first=True)
-                self.tok_masks = pad_sequence(tok_masks, batch_first=True)
-                self.labels = pad_sequence(labels, batch_first=True)
-                if uposes is not None:
-                    self.uposes = pad_sequence(uposes, batch_first=True)
-                else: self.uposes = None
-                if deprels is not None:
-                    self.deprels = pad_sequence(deprels, batch_first=True)
-                else: self.deprels = None
-                if dheads is not None:
-                    self.dheads = pad_sequence(dheads, batch_first=True)
-                else: self.dheads = None
-                self.labels = pad_sequence(labels, batch_first=True)
-
-        def getBatchEncoding(self):
-                dico = { 'input_ids': self.tok_ids, 'token_type_ids': self.tok_types, 'attention_mask': self.tok_masks }
-                return BatchEncoding(dico)
-        
+    def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels, uposes = None, deprels = None, dheads = None):
+        self.sentence_ids = sentence_ids
+        self.tokens = tokens
+        self.tok_ids = pad_sequence(tok_ids, batch_first=True) #bert token ids
+        self.tok_types = pad_sequence(tok_types, batch_first=True)
+        self.tok_masks = pad_sequence(tok_masks, batch_first=True)
+        self.labels = pad_sequence(labels, batch_first=True)
+        if uposes is not None:
+            self.uposes = pad_sequence(uposes, batch_first=True)
+        else: self.uposes = None
+        if deprels is not None:
+            self.deprels = pad_sequence(deprels, batch_first=True)
+        else: self.deprels = None
+        if dheads is not None:
+            self.dheads = pad_sequence(dheads, batch_first=True)
+        else: self.dheads = None
+        self.labels = pad_sequence(labels, batch_first=True)
+
+    def getBatchEncoding(self):
+        dico = { 'input_ids': self.tok_ids, 'token_type_ids': self.tok_types, 'attention_mask': self.tok_masks }
+        return BatchEncoding(dico)
+    
 def generate_sentence_list(corpus, sset, fmt):
-        #move that part to parse_corpus.py
-        parsed_data = os.path.join("parsed_data", f"parsed_{corpus}_{sset}.{fmt}.npz")
-        #print("PATH", parsed_data)
-        if not os.path.isfile(parsed_data):
-                print("you must parse the corpus before training it")
-                sys.exit()
-        data = np.load(parsed_data, allow_pickle = True)
-        tok_ids, toks, labels = [data[f] for f in data.files]
-
-        sentences = []
-
-        if fmt == 'conllu':
-                previous_i = 0
-                for index, tok_id in enumerate(tok_ids):
-                        if index > 0 and (tok_id == 1 or tok_id == '1' or (isinstance(tok_id, str) and tok_id.startswith('1-'))):
-                                if index - previous_i <= 510:
-                                        sentences.append((toks[previous_i:index], labels[previous_i:index]))
-                                        previous_i = index
-                                else: 
-                                        sep = previous_i + 510
-                                        sentences.append((toks[previous_i:sep], labels[previous_i:sep]))
-                                        if sep - previous_i > 510:
-                                                print("still too long sentence...")
-                                                sys.exit()
-                                        sentences.append((toks[sep:index], labels[sep:index]))
-        else: 
-                max_length = 510
-                nb_toks = len(tok_ids)
-                for i in range(0, nb_toks - max_length, max_length):
-                        #add slices of 510 tokens
-                        sentences.append((toks[i:(i + max_length)], labels[i:(i + max_length)]))
-                sentences.append((toks[-(nb_toks % max_length):], labels[-(nb_toks % max_length):]))
-
-        indexed_sentences = [0] * len(sentences)
-        for i, sentence in enumerate(sentences):
-                indexed_sentences[i] = (i, sentence)
-
-        return indexed_sentences
+    #move that part to parse_corpus.py
+    parsed_data = os.path.join("parsed_data", f"parsed_{corpus}_{sset}.{fmt}.npz")
+    #print("PATH", parsed_data)
+    if not os.path.isfile(parsed_data):
+        print("you must parse the corpus before training it")
+        sys.exit()
+    data = np.load(parsed_data, allow_pickle = True)
+    tok_ids, toks, labels = [data[f] for f in data.files]
+
+    sentences = []
+
+    if fmt == 'conllu':
+        previous_i = 0
+        for index, tok_id in enumerate(tok_ids):
+            if index > 0 and (tok_id == 1 or tok_id == '1' or (isinstance(tok_id, str) and tok_id.startswith('1-'))):
+                if index - previous_i <= 510:
+                    sentences.append((toks[previous_i:index], labels[previous_i:index]))
+                    previous_i = index
+                else: 
+                    sep = previous_i + 510
+                    sentences.append((toks[previous_i:sep], labels[previous_i:sep]))
+                    if sep - previous_i > 510:
+                        print("still too long sentence...")
+                        sys.exit()
+                    sentences.append((toks[sep:index], labels[sep:index]))
+    else: 
+        max_length = 510
+        nb_toks = len(tok_ids)
+        for i in range(0, nb_toks - max_length, max_length):
+            #add slices of 510 tokens
+            sentences.append((toks[i:(i + max_length)], labels[i:(i + max_length)]))
+        sentences.append((toks[-(nb_toks % max_length):], labels[-(nb_toks % max_length):]))
+
+    indexed_sentences = [0] * len(sentences)
+    for i, sentence in enumerate(sentences):
+        indexed_sentences[i] = (i, sentence)
+
+    return indexed_sentences
 
 def add_cls_sep(sentence):
-        return ['[CLS]'] + list(sentence) + ['[SEP]']
-        
+    return ['[CLS]'] + list(sentence) + ['[SEP]']
+    
 def toks_to_ids(sentence):
-        #print("sentence=", sentence)
-        tokens = ['[CLS]'] + list(sentence) + ['[SEP]']
-        #print("tokens=", tokens)
-        tokenizer = BertTokenizer.from_pretrained(bert)
-        return torch.tensor(tokenizer.convert_tokens_to_ids(tokens)) #len(tokens)
-        #print("token_ids=", token_ids)
+    #print("sentence=", sentence)
+    tokens = ['[CLS]'] + list(sentence) + ['[SEP]']
+    #print("tokens=", tokens)
+    return torch.tensor(tokenizer.convert_tokens_to_ids(tokens)) #len(tokens)
+    #print("token_ids=", token_ids)
 
 def make_labels(sentence):
-        zero = np.array([0])
-        add_two = np.concatenate((np.concatenate((zero, sentence)), zero)) #add label 0 for [CLS] and [SEP]
-        return torch.from_numpy(add_two).float()
+    zero = np.array([0])
+    add_two = np.concatenate((np.concatenate((zero, sentence)), zero)) #add label 0 for [CLS] and [SEP]
+    return torch.from_numpy(add_two).float()
 
 def make_tok_types(l):
-        return torch.zeros(l, dtype=torch.int32)
+    return torch.zeros(l, dtype=torch.int32)
 
 def make_tok_masks(l):
-        return torch.ones(l, dtype=torch.int32)
+    return torch.ones(l, dtype=torch.int32)
 
 def collate_batch(batch):
-        sentence_ids, token_batch, label_batch = [i for i, (_, _) in batch], [j for _, (j, _) in batch], [k for _, (_, k) in batch]
-        #mappings = [make_mapping(sentence) for sentence in token_batch]
-        labels = [make_labels(sentence) for sentence in label_batch]
-        tokens = [add_cls_sep(sentence) for sentence in token_batch]
-        tok_ids = [toks_to_ids(sentence) for sentence in token_batch]
-        lengths = [len(toks) for toks in tok_ids]
-        tok_types = [make_tok_types(l) for l in lengths]
-        tok_masks = [make_tok_masks(l) for l in lengths]
-        
-        return SentenceBatch(sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels)
-        
+    sentence_ids, token_batch, label_batch = [0] * len(batch), [0] * len(batch), [0] * len(batch)
+
+    for i, (ids, (toks, labs)) in enumerate(batch):
+        sentence_ids[i] = ids
+        token_batch[i] = toks
+        label_batch[i] = labs
+    
+    labels = [make_labels(sentence) for sentence in label_batch]
+    tokens = [add_cls_sep(sentence) for sentence in token_batch]
+    tok_ids = [torch.tensor(tokenizer.convert_tokens_to_ids(sentence)) for sentence in tokens]
+    lengths = [len(toks) for toks in tok_ids]
+    tok_types = [make_tok_types(l) for l in lengths]
+    tok_masks = [make_tok_masks(l) for l in lengths]
+    
+    return SentenceBatch(sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels)
+    
 def train(corpus, fmt):
-        print(f'starting training of {corpus} in format {fmt}...')
-        data = generate_sentence_list(corpus, 'train', fmt)
-
-        torch.manual_seed(1)
-
-        #PARAMETERS
-        batch_size = 32 if fmt == 'conllu' else 4
-        num_epochs = 10
-        lr = 0.0001
-        reg = 0.0005
-        dropout = 0.01
-        bidirectional = True
-        params = { 'fm' : fmt, 'bs': batch_size, 'ne': num_epochs, 'lr': lr, 'rg': reg, 'do': dropout, 'bi': bidirectional }
-
-        dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
-        model = LSTM(batch_size, 768, 64, num_layers=1, bidirectional=bidirectional) 
-        loss_fn = nn.BCELoss() #BCELoss
-        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)
-        model.train()
-        
-        l = len(dataloader.dataset)
-        
-        for epoch in range(num_epochs):
-                total_acc = 0
-                total_loss = 0
-                tp = 0
-                fp = 0 
-                fn = 0
-        
-                for sentence_batch in tqdm(dataloader):
-                        optimizer.zero_grad()
-                        label_batch = sentence_batch.labels
-                        #print("label_batch", label_batch.shape)
-                        #print("tok_ids", sentence_batch.tok_ids.shape)
-                        pred = model(sentence_batch.getBatchEncoding())
-                        #print("pred", pred.shape)
-                        loss = loss_fn(pred, label_batch)
-                        loss.backward()
-                        optimizer.step()
-                        pred_binary = (pred >= 0.5).float()
-                        
-                        for i in range(label_batch.size(0)):
-                                for j in range(label_batch.size(1)):
-                                        if pred_binary[i,j] == 1.:
-                                                if label_batch[i,j] == 1.:
-                                                        tp += 1
-                                                else: fp += 1
-                                        elif label_batch[i,j] == 1.: fn += 1
-                        #print("tp,fp,fn=",tp/label_batch.size(1),fp/label_batch.size(1),fn/label_batch.size(1))
-                        #nb_1 = pred_binary.sum()
-                        #print("nb predicted 1=", nb_1)
-                        sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
-                        assert (sum_score <= batch_size)
-                        total_acc += sum_score
-                        total_loss += loss.item() #*label_batch.size(0)
-
-                f1 = tp / (tp + (fp + fn) / 2)
-                print(f"Epoch {epoch} Accuracy {total_acc/l} Loss {total_loss/l} F1 {f1}")
-        
-        print('done training')
-        output_file = save_model(model, 'baseline', corpus, params)
-        print(f'model saved at {output_file}')
+    print(f'starting training of {corpus} in format {fmt}...')
+    data = generate_sentence_list(corpus, 'train', fmt)
+
+    torch.manual_seed(1)
+
+    #PARAMETERS
+    batch_size = 32 if fmt == 'conllu' else 4
+    n_epochs = 10
+    lr = 0.0001
+    reg = 0.001
+    dropout = 0.1
+    n_layers = 1
+    n_hidden = 64
+    bidirectional = True
+    params = { 'fm' : fmt, 'nl': n_layers, 'nh': n_hidden, 'bs': batch_size, 'ne': n_epochs, 'lr': lr, 'rg': reg, 'do': dropout, 'bi': bidirectional }
+
+    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
+    model = LSTM(batch_size, 768, n_hidden, n_layers, bidirectional=bidirectional) 
+    loss_fn = nn.BCELoss() 
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)
+    model.train()
+    
+    l = len(dataloader.dataset)
+    
+    for epoch in range(n_epochs):
+        total_acc = 0
+        total_loss = 0
+        tp = 0
+        fp = 0 
+        fn = 0
+    
+        for sentence_batch in tqdm(dataloader):
+            optimizer.zero_grad()
+            label_batch = sentence_batch.labels
+            pred = model(sentence_batch.getBatchEncoding())
+            loss = loss_fn(pred, label_batch)
+            loss.backward()
+            optimizer.step()
+            pred_binary = (pred >= 0.5).float()
+            
+            for i in range(label_batch.size(0)):
+                for j in range(label_batch.size(1)):
+                    if pred_binary[i,j] == 1.:
+                        if label_batch[i,j] == 1.:
+                            tp += 1
+                        else: fp += 1
+                    elif label_batch[i,j] == 1.: fn += 1
+            sum_score = ((pred_binary == label_batch).float().sum().item()) / label_batch.size(1)
+            assert (sum_score <= batch_size)
+            total_acc += sum_score
+            total_loss += loss.item() #*label_batch.size(0)
+
+        f1 = tp / (tp + (fp + fn) / 2)
+        print(f"Epoch {epoch} Accuracy {total_acc/l} Loss {total_loss/l} F1 {f1}")
+    
+    print('done training')
+    output_file = save_model(model, 'baseline', corpus, params)
+    print(f'model saved at {output_file}')
 
 def save_model(model, model_type, corpus, params):
 
-        models_dir = 'saved_models'
-        if not os.path.isdir(models_dir):
-                os.system(f'mkdir {models_dir}')
-        corpus_dir = os.path.join(models_dir, corpus)        
-        if not os.path.isdir(corpus_dir):
-                os.system(f'mkdir {corpus_dir}')
-        model_file = f"{model_type}_{corpus}_{params['fm']}_{params['bs']}_{params['ne']}_{params['lr']}_{params['rg']}_{params['do']}_{params['bi']}.pth"
-        output_file = os.path.join(corpus_dir, model_file)
-        if not os.path.isfile(output_file):
-                torch.save(model, output_file)
-        else:
-                print("Model already registers. Do you wish to overwrite? (Y/n)")
-                user = input()
-                if user == "" or user == "Y" or user == "y":
-                        torch.save(model, output_file)
-        return output_file
+    models_dir = 'saved_models'
+    if not os.path.isdir(models_dir):
+        os.system(f'mkdir {models_dir}')
+    corpus_dir = os.path.join(models_dir, corpus)    
+    if not os.path.isdir(corpus_dir):
+        os.system(f'mkdir {corpus_dir}')
+    model_file = f"{model_type}_{corpus}_{params['nl']}_{params['nh']}_{params['fm']}_{params['bs']}_{params['ne']}_{params['lr']}_{params['rg']}_{params['do']}_{params['bi']}.pth"
+    output_file = os.path.join(corpus_dir, model_file)
+    if not os.path.isfile(output_file):
+        torch.save(model, output_file)
+    else:
+        print("Model already exists. Do you wish to overwrite? (Y/n)")
+        user = input()
+        if user == "" or user == "Y" or user == "y":
+            torch.save(model, output_file)
+    return output_file
 
 def main():
-        if len(sys.argv) < 2 or len(sys.argv) > 3:
-                print("usage: train_model_baseline.py <corpus> [<conllu/tok>]")
-                sys.exit()
-        corpus = sys.argv[1]
-        fmt = 'conllu'
-        if len(sys.argv) == 3:
-                fmt = sys.argv[2]
-                if fmt != 'conllu' and fmt != 'tok':
-                        print("usage: train_model_baseline.py <corpus> [<conllu/tok>]")
-                        sys.exit()
-        train(corpus, fmt) 
+    parser = argparse.ArgumentParser(description='Train baseline model')
+    parser.add_argument('--corpus', required=True, help='corpus to train')
+    parser.add_argument('--format', default='conllu', help='tok or conllu')
+    params = parse.parse_args()
+    train(params.corpus, params.format) 
 
 if __name__ == '__main__':
-        main()
+    main()