From 571902b3138a4686c267cdf51df88e3087a048a3 Mon Sep 17 00:00:00 2001
From: emetheni <eleni.metheniti@irit.fr>
Date: Thu, 29 Feb 2024 18:08:36 +0100
Subject: [PATCH] format code to black

---
 README.md                  |   2 +-
 classifier_pytorch.py      | 223 +++++++--------
 configure.py               | 105 ++++---
 make_mappings_zero-shot.py |  50 ++--
 utils.py                   | 555 ++++++++++++++++++++-----------------
 5 files changed, 511 insertions(+), 424 deletions(-)

diff --git a/README.md b/README.md
index 4ee9618..c9ca654 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ The full list of datasets with statistics: [here](https://github.com/disrpt/shar
 * transformers
 * scikit-learn
 
-Install requirements with ```pip install requirements.txt```.
+Install requirements with ```pip install -r requirements.txt```.
 
 ## Run 
 
diff --git a/classifier_pytorch.py b/classifier_pytorch.py
index 688f6f4..25d5858 100644
--- a/classifier_pytorch.py
+++ b/classifier_pytorch.py
@@ -3,7 +3,12 @@
 
 import torch
 import numpy as np
-from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
+from transformers import (
+    AutoModel,
+    AutoTokenizer,
+    get_linear_schedule_with_warmup,
+    set_seed,
+)
 from torch import nn
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
@@ -23,33 +28,36 @@ now = datetime.now()
 dt_string = now.strftime("%d.%m.%y-%H:%M:%S")
 layers_to_freeze = args.freeze_layers.split(";")
 
-print('\nTraining with datasets: ' + args.langs_to_use)
-print('Mappings file: ' + args.mappings_file, flush='True')
+print("Training with datasets: " + args.langs_to_use)
+print("Mappings file: " + args.mappings_file, flush="True")
 
 
 # ===============
 # Dataset class
 # ===============
 
-class Dataset(torch.utils.data.Dataset):
 
+class Dataset(torch.utils.data.Dataset):
     def __init__(self, sentences):
-
         self.labels = [sent[-1] for sent in sentences]
-        self.texts = [tokenizer(sent[-2], 
-                                is_split_into_words=True,                              
-                                padding='max_length', 
-                                max_length = 512, 
-                                truncation=True,
-                                return_tensors="pt") 
-                                for sent in sentences]
+        self.texts = [
+            tokenizer(
+                sent[-2],
+                is_split_into_words=True,
+                padding="max_length",
+                max_length=512,
+                truncation=True,
+                return_tensors="pt",
+            )
+            for sent in sentences
+        ]
 
     def classes(self):
         return self.labels
 
     def __len__(self):
         return len(self.labels)
-    
+
     def get_batch_labels(self, idx):
         # Fetch a batch of labels
         return np.array(self.labels[idx])
@@ -59,12 +67,12 @@ class Dataset(torch.utils.data.Dataset):
         return self.texts[idx]
 
     def __getitem__(self, idx):
-
         batch_texts = self.get_batch_texts(idx)
         batch_y = self.get_batch_labels(idx)
 
         return batch_texts, batch_y
 
+
 # ===============
 # Load datasets
 # ===============
@@ -72,24 +80,32 @@ class Dataset(torch.utils.data.Dataset):
 # Open mappings
 mappings, inv_mappings = open_mappings(args.mappings_file)
 batch_size = args.batch_size
-tokenizer  = AutoTokenizer.from_pretrained(args.transformer_model)
+tokenizer = AutoTokenizer.from_pretrained(args.transformer_model)
 
-train_sentences, dev_dict_sentences, test_dict_sentences, framework_labels = open_sentences_with_lang(args.data_path, mappings)
+(
+    train_sentences,
+    dev_dict_sentences,
+    test_dict_sentences,
+    framework_labels,
+) = open_sentences_with_lang(args.data_path, mappings)
 
 # Determine linear size (= number of classes in the sets + 1)
 num_labels = len(set(sent[-1] for sent in train_sentences)) + 1
 
 # make train/dev datasets
 train_dataset = Dataset(train_sentences)
-dev_dataset   = {corpus: Dataset(s) for corpus, s in dev_dict_sentences.items()}
-test_dataset  = {corpus: Dataset(s) for corpus, s in test_dict_sentences.items()}
+dev_dataset = {corpus: Dataset(s) for corpus, s in dev_dict_sentences.items()}
+test_dataset = {corpus: Dataset(s) for corpus, s in test_dict_sentences.items()}
 
 # Make dasets with batches and dataloader
 train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
-dev_dict_dataloader = {corpus: DataLoader(dev_data, batch_size) 
-                        for corpus, dev_data in dev_dataset.items()}
-test_dict_dataloader = {corpus: DataLoader(test_data, batch_size) 
-                        for corpus, test_data in test_dataset.items()}
+dev_dict_dataloader = {
+    corpus: DataLoader(dev_data, batch_size) for corpus, dev_data in dev_dataset.items()
+}
+test_dict_dataloader = {
+    corpus: DataLoader(test_data, batch_size)
+    for corpus, test_data in test_dataset.items()
+}
 
 print("\nDatasets loaded!\n")
 
@@ -97,22 +113,20 @@ print("\nDatasets loaded!\n")
 # Model setup
 # ===============
 
-class TransformerClassifier(nn.Module):
 
+class TransformerClassifier(nn.Module):
     def __init__(self, dropout=args.dropout):
-
         super(TransformerClassifier, self).__init__()
 
         self.tr_model = AutoModel.from_pretrained(args.transformer_model)
         self.dropout = nn.Dropout(dropout)
-        self.linear = nn.Linear(768, num_labels) # bert input x num of classes
+        self.linear = nn.Linear(768, num_labels)  # bert input x num of classes
         self.relu = nn.ReLU()
 
     def forward(self, input_id, mask):
-        
-        outputs = self.tr_model(input_ids = input_id, 
-                                attention_mask = mask,
-                                return_dict = True)['last_hidden_state'][:, 0, :]
+        outputs = self.tr_model(
+            input_ids=input_id, attention_mask=mask, return_dict=True
+        )["last_hidden_state"][:, 0, :]
         dropout_output = self.dropout(outputs)
         linear_output = self.linear(dropout_output)
         final_layer = self.relu(linear_output)
@@ -123,133 +137,120 @@ class TransformerClassifier(nn.Module):
 model = TransformerClassifier()
 
 
-def train(model, 
-          train_dataloader, 
-          dev_dict_dataloader, 
-          test_dict_sentences, 
-          test_dict_dataloader,
-          epochs, 
-          #specific_results
-         ):
-
+def train(
+    model,
+    train_dataloader,
+    dev_dict_dataloader,
+    test_dict_sentences,
+    test_dict_dataloader,
+    epochs,
+    # specific_results
+):
     device = torch.device("cpu")
 
     criterion = nn.CrossEntropyLoss()
-    optimizer = AdamW(model.parameters(), #Adam
-                      lr = 2e-5, #1e-6
-                      eps = 1e-8
-                    )
+    optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)  # Adam  # 1e-6
 
-    if args.use_cuda == 'yes':
+    if args.use_cuda == "yes":
         device = torch.device("cuda")
         model = model.cuda()
         criterion = criterion.cuda()
-    
+
     gradient_accumulation_steps = args.gradient_accumulation_steps
     total_steps = len(train_dataloader) * epochs
-    scheduler = get_linear_schedule_with_warmup(optimizer, 
-                                                num_warmup_steps = 0,
-                                                num_training_steps = total_steps)
-    
+    scheduler = get_linear_schedule_with_warmup(
+        optimizer, num_warmup_steps=0, num_training_steps=total_steps
+    )
+
     seed_val = 42
     set_seed(seed_val)
     torch.manual_seed(seed_val)
     torch.cuda.manual_seed_all(seed_val)
-    
+
     # Freeze layers if you want
-    if args.freeze_layers != '':
+    if args.freeze_layers != "":
         for name, param in model.named_parameters():
             if any(x in name for x in layers_to_freeze):
                 param.requires_grad = False
 
     for epoch_num in range(0, epochs):
-        print('\n=== Epoch {:} / {:} ==='.format(epoch_num + 1, epochs))
-        
+        print("\n=== Epoch {:} / {:} ===".format(epoch_num + 1, epochs))
+
         model.train()
 
         total_acc_train = 0
         total_loss_train = 0
         batch_counter = 0
-        
+
         for train_input, train_label in tqdm(train_dataloader):
             batch_counter += 1
             train_label = train_label.to(device)
-            mask = train_input['attention_mask'].to(device)
-            input_id = train_input['input_ids'].squeeze(1).to(device)
+            mask = train_input["attention_mask"].to(device)
+            input_id = train_input["input_ids"].squeeze(1).to(device)
 
             output = model(input_id, mask)
-            
+
             # Compute Loss and Perform Back-propagation
             loss = criterion(output, train_label.long())
 
-
             # Normalize the Gradients
             loss = loss / gradient_accumulation_steps
             loss.backward()
 
-            
-            if (batch_counter % gradient_accumulation_steps == 0):
+            if batch_counter % gradient_accumulation_steps == 0:
                 # Update Optimizer
-                optimizer.step() 
+                optimizer.step()
                 optimizer.zero_grad()
-                
+
                 model.zero_grad()
                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                 scheduler.step()
-            
+
         # ------ Validation --------
-        
-        print('\nValidation for epoch:', epoch_num + 1)
-        
+
+        print("\nValidation for epoch:", epoch_num + 1)
+
         # Dev and test results for each corpus. We don't need to save the results.
         for corpus in dev_dict_dataloader:
-            dev_results = get_predictions(
-                                model, 
-                                corpus, 
-                                dev_dict_dataloader[corpus]
-                                )
+            dev_results = get_predictions(model, corpus, dev_dict_dataloader[corpus])
             better_dev_results = get_better_predictions(
-                                    model, 
-                                    corpus, 
-                                    dev_dict_dataloader[corpus], 
-                                    framework_labels[corpus.split('.')[1]], 
-                                    inv_mappings,
-                                    epoch_num+1,
-                                    save_conf_matrix=False
-                                    )
-            
+                model,
+                corpus,
+                dev_dict_dataloader[corpus],
+                framework_labels[corpus.split(".")[1]],
+                inv_mappings,
+                epoch_num + 1,
+                save_conf_matrix=False,
+            )
+
         # ------ Test --------
-        
-        print('\nTest results for epoch:', epoch_num + 1)
-        
+
+        print("\nTest results for epoch:", epoch_num + 1)
+
         for corpus in test_dict_dataloader:
-            test_results = get_predictions(
-                                model, 
-                                corpus, 
-                                test_dict_dataloader[corpus]
-                                )
+            test_results = get_predictions(model, corpus, test_dict_dataloader[corpus])
             better_test_results = get_better_predictions(
-                                    model, 
-                                    corpus, 
-                                    test_dict_dataloader[corpus], 
-                                    framework_labels[corpus.split('.')[1]], 
-                                    inv_mappings,
-                                    epoch_num+1,
-                                    save_conf_matrix=False
-                                    )
-
-                
-# ------- Start the training -------   
-
-print('\nModel: ', args.transformer_model)
-print('Batch size: ', args.batch_size * args.gradient_accumulation_steps)
-print('\nStart training...\n')
-train(model, 
-      train_dataloader,
-      dev_dict_dataloader, 
-      test_dict_sentences, 
-      test_dict_dataloader,
-      args.num_epochs
-     )
-print('\nTraining Done!')
-
+                model,
+                corpus,
+                test_dict_dataloader[corpus],
+                framework_labels[corpus.split(".")[1]],
+                inv_mappings,
+                epoch_num + 1,
+                save_conf_matrix=False,
+            )
+
+
+# ------- Start the training -------
+
+print("\nModel: ", args.transformer_model)
+print("Batch size: ", args.batch_size * args.gradient_accumulation_steps)
+print("\nStart training...\n")
+train(
+    model,
+    train_dataloader,
+    dev_dict_dataloader,
+    test_dict_sentences,
+    test_dict_dataloader,
+    args.num_epochs,
+)
+print("\nTraining Done!")
diff --git a/configure.py b/configure.py
index 10de0b5..a6cca0e 100644
--- a/configure.py
+++ b/configure.py
@@ -1,57 +1,96 @@
 import argparse
 import sys
 
+
 def parse_args():
     """
     Parse input arguments.
     """
     parser = argparse.ArgumentParser()
-    
+
     # path to data
-    parser.add_argument("--data_path", default="./data", type=str, 
-                        help="The path to the shared task data file from Github.")
-    
+    parser.add_argument(
+        "--data_path",
+        default="./data",
+        type=str,
+        help="The path to the shared task data file from Github.",
+    )
+
     # label mappings to integers
-    parser.add_argument("--mappings_file", default="mappings/mappings_substitutions.tsv", type=str, 
-                        help="The mappings file for all relations.")
+    parser.add_argument(
+        "--mappings_file",
+        default="mappings/mappings_substitutions.tsv",
+        type=str,
+        help="The mappings file for all relations.",
+    )
 
     # transformer model
-    parser.add_argument("--transformer_model", default="bert-base-multilingual-cased", type=str, 
-                        help="Model used, default: bert-multilingual-base-cased")
+    parser.add_argument(
+        "--transformer_model",
+        default="bert-base-multilingual-cased",
+        type=str,
+        help="Model used, default: bert-multilingual-base-cased",
+    )
 
     # Number of training epochs
-    parser.add_argument("--num_epochs", default=10, type=int, 
-                        help="Number of training epochs. Default: 10")
-    
+    parser.add_argument(
+        "--num_epochs",
+        default=10,
+        type=int,
+        help="Number of training epochs. Default: 10",
+    )
+
     # Number of gradient accumulation steps
-    parser.add_argument("--gradient_accumulation_steps", default=16, type=int, 
-                        help="Number of gradient accumulation steps. Default: 16")
-    
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        default=16,
+        type=int,
+        help="Number of gradient accumulation steps. Default: 16",
+    )
+
     # Dropout
-    parser.add_argument("--dropout", default=0.1, type=float, 
-                        help="Dropout.")
-    
+    parser.add_argument("--dropout", default=0.1, type=float, help="Dropout.")
+
     # Batch size
-    parser.add_argument("--batch_size", default=8, type=int, 
-                        help="With CUDA: max. 8, without: max. 16. Default: 8")
-    
+    parser.add_argument(
+        "--batch_size",
+        default=8,
+        type=int,
+        help="With CUDA: max. 8, without: max. 16. Default: 8",
+    )
+
     # Use CUDA
-    parser.add_argument("--use_cuda", default='yes', type=str, 
-                        help="Use CUDA [yes/no]. Careful of batch size!")   
-    
+    parser.add_argument(
+        "--use_cuda",
+        default="yes",
+        type=str,
+        help="Use CUDA [yes/no]. Careful of batch size!",
+    )
+
     # freeze layers
-    parser.add_argument("--freeze_layers", default='', type=str, 
-                        help="List of layer(s) to freeze, a str separated by ;. Example: 'layer.1;layer.2'")   
-       
+    parser.add_argument(
+        "--freeze_layers",
+        default="",
+        type=str,
+        help="List of layer(s) to freeze, a str separated by ;. Example: 'layer.1;layer.2'",
+    )
+
     # normalize direction
-    parser.add_argument("--normalize_direction", default='yes', type=str, 
-                        help="Change order of sentences when the direction of relations is 1<2 to 2>1.") 
-    
+    parser.add_argument(
+        "--normalize_direction",
+        default="yes",
+        type=str,
+        help="Change order of sentences when the direction of relations is 1<2 to 2>1.",
+    )
+
     # only specific languages/corpora
-    parser.add_argument("--langs_to_use", default='@', type=str, 
-                        help="List of languages/corpora to use, a str separated by ;")   
-    
-            
+    parser.add_argument(
+        "--langs_to_use",
+        default="@",
+        type=str,
+        help="List of languages/corpora to use, a str separated by ;",
+    )
+
     args = parser.parse_args()
 
     return args
diff --git a/make_mappings_zero-shot.py b/make_mappings_zero-shot.py
index e433935..a740e98 100644
--- a/make_mappings_zero-shot.py
+++ b/make_mappings_zero-shot.py
@@ -11,13 +11,13 @@ args = parse_args()
 # -----------------------------------
 # open substitutions per file
 mappings = {}
-with open('mappings/mappings_substitutions.tsv', 'r', encoding='utf-8') as f:
+with open("mappings/mappings_substitutions.tsv", "r", encoding="utf-8") as f:
     next(f)
     for line in f:
-        l = line.strip().split('\t')
+        l = line.strip().split("\t")
         mappings[l[0]] = l[1]
 
-        
+
 # find the labels that were changed
 inv_mappings = {}
 subs = {}
@@ -26,31 +26,36 @@ for label, num in mappings.items():
         inv_mappings[num] = label
     else:
         subs[label] = inv_mappings[num]
-        
+
 
 # -----------------------------------
 # define which language to use with the arguments
-languages = args.langs_to_use.split(';')
+languages = args.langs_to_use.split(";")
 
 
-corpora = [folder 
-           for folder in os.listdir(args.data_path) 
-           if any(l in folder for l in languages)]
+corpora = [
+    folder
+    for folder in os.listdir(args.data_path)
+    if any(l in folder for l in languages)
+]
+
+files = [
+    "/".join([args.data_path, corpus, f])
+    for corpus in corpora
+    for f in os.listdir(args.data_path + "/" + corpus)
+]
 
-files = ['/'.join([args.data_path, corpus, f])
-         for corpus in corpora
-         for f in os.listdir(args.data_path + '/' + corpus)]
 
 # open the files
 def read_file(file):
-    ''' Open the relations file. '''
+    """Open the relations file."""
     relations = []
     sub_rels = []
-    with open(file, 'r', encoding='utf-8') as f:
+    with open(file, "r", encoding="utf-8") as f:
         next(f)
         for line in f:
             try:
-                l = line.strip().split('\t')
+                l = line.strip().split("\t")
                 if not l[11].lower() in subs:
                     relations.append(l[11].lower())
                 else:
@@ -60,8 +65,7 @@ def read_file(file):
         return relations, sub_rels
 
 
-rel_files = [f for f in files if any (x in f for x in ['train']
-                                     )]
+rel_files = [f for f in files if any(x in f for x in ["train"])]
 
 good_rels = []
 sub_rels = []
@@ -71,7 +75,7 @@ for f in rel_files:
     sub_rels += y
 
 dict_labels = dict(enumerate(list(set(good_rels))))
-corpora_labels = {v:k for k, v in dict_labels.items()}
+corpora_labels = {v: k for k, v in dict_labels.items()}
 
 
 leftovers = []
@@ -80,12 +84,12 @@ for sub in sub_rels:
     try:
         corpora_labels[sub] = corpora_labels[subs[sub]]
     except KeyError:
-        corpora_labels[subs[sub]] =  max(list(corpora_labels.values())) + 1
+        corpora_labels[subs[sub]] = max(list(corpora_labels.values())) + 1
         corpora_labels[sub] = corpora_labels[subs[sub]]
 
-corpora_labels['unk'] = max(list(corpora_labels.values())) + 1
-        
-with open('mappings/' + args.mappings_file, 'w') as f:
-    f.write('LABEL\tMAPPING\n')
+corpora_labels["unk"] = max(list(corpora_labels.values())) + 1
+
+with open("mappings/" + args.mappings_file, "w") as f:
+    f.write("LABEL\tMAPPING\n")
     for k, v in corpora_labels.items():
-        f.write(k + '\t' + str(v) + '\n')
+        f.write(k + "\t" + str(v) + "\n")
diff --git a/utils.py b/utils.py
index ff328ce..d5108e1 100644
--- a/utils.py
+++ b/utils.py
@@ -6,7 +6,12 @@ import torch
 from transformers import AutoConfig, AutoTokenizer
 from configure import parse_args
 import numpy as np
-from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, ConfusionMatrixDisplay
+from sklearn.metrics import (
+    accuracy_score,
+    confusion_matrix,
+    classification_report,
+    ConfusionMatrixDisplay,
+)
 import matplotlib.pyplot as plt
 import seaborn as sns
 from time import sleep
@@ -18,129 +23,170 @@ dt_string = now.strftime("%d.%m.%y-%H:%M:%S")
 
 args = parse_args()
 
+
 def open_mappings(mappings_file):
-    
-    ''' Open the mappings file into a dictionary.'''
-    
+    """Open the mappings file into a dictionary."""
+
     mappings = {}
-    with open(mappings_file, 'r') as f:
+    with open(mappings_file, "r") as f:
         next(f)
         for line in f:
-            l = line.strip().split('\t')
+            l = line.strip().split("\t")
             mappings[l[0]] = int(l[-1])
-           
+
     # reject the converted labels
     inv_mappings = {}
     for k, v in mappings.items():
         if v not in inv_mappings:
             inv_mappings[v] = k
-    
+
     return mappings, inv_mappings
 
 
 def encode_label(og_label, mappings_dict):
-    ''' Encode the label. '''
+    """Encode the label."""
 
     label = og_label.lower().strip()
     if label in mappings_dict:
         return mappings_dict[label]
     else:
-        return mappings_dict['unk']
-    
-
-def open_file(filename, mappings_dict):   
-    
-    ''' Function to open a .rels file. 
-        Arguments: 
-        - filename: the path to a .rels file 
-        - mappings_dict: a dictionary of mappings of unique labels to integers
-        Returns a list of lists, where each list is:
-        the line + [two sentences combined with special BERT token, encoded label]
-    '''
-    
-    max_len = 254 # 512 (max bert len) / 2 (2 sents) -2 (special tokens)
+        return mappings_dict["unk"]
+
+
+def open_file(filename, mappings_dict):
+    """Function to open a .rels file.
+    Arguments:
+    - filename: the path to a .rels file
+    - mappings_dict: a dictionary of mappings of unique labels to integers
+    Returns a list of lists, where each list is:
+    the line + [two sentences combined with special BERT token, encoded label]
+    """
+
+    max_len = 254  # 512 (max bert len) / 2 (2 sents) -2 (special tokens)
     lines = []
-    SEP_token = '[SEP]'
+    SEP_token = "[SEP]"
 
-    with open(filename, 'r', encoding='utf-8') as f:
+    with open(filename, "r", encoding="utf-8") as f:
         next(f)
         for line in f:
-            l = line.strip().split('\t')
-            
+            l = line.strip().split("\t")
+
             if len(l) > 1:
                 # chop the sentences to max_len if too long
-                sent_1 = l[3].split(' ')
-                sent_2 = l[4].split(' ')      
-                
+                sent_1 = l[3].split(" ")
+                sent_2 = l[4].split(" ")
+
                 if len(sent_1) > max_len:
                     sent_1 = sent_1[:max_len]
                 if len(sent_2) > max_len:
                     sent_2 = sent_2[:max_len]
-                
+
                 # flip them if different direction
-                if args.normalize_direction == 'yes':
-                    if l[9] == '1>2':
-                        lines.append(l + [sent_1 + [SEP_token] + sent_2, encode_label(l[-1], mappings_dict)])
+                if args.normalize_direction == "yes":
+                    if l[9] == "1>2":
+                        lines.append(
+                            l
+                            + [
+                                sent_1 + [SEP_token] + sent_2,
+                                encode_label(l[-1], mappings_dict),
+                            ]
+                        )
                     else:
-                        lines.append(l + [sent_2 + [SEP_token] + sent_1, encode_label(l[-1], mappings_dict)])
+                        lines.append(
+                            l
+                            + [
+                                sent_2 + [SEP_token] + sent_1,
+                                encode_label(l[-1], mappings_dict),
+                            ]
+                        )
                 else:
-                    lines.append(l + [sent_1 + [SEP_token] + sent_2, encode_label(l[-1], mappings_dict)])
+                    lines.append(
+                        l
+                        + [
+                            sent_1 + [SEP_token] + sent_2,
+                            encode_label(l[-1], mappings_dict),
+                        ]
+                    )
 
     return lines
 
 
-def open_file_with_lang(filename, mappings_dict):   
-    
-    ''' Same as above, but add the lcf toekns at the start of the sequence. '''
-    
-    max_len = 254 # 512 (max bert len) / 2 (2 sents) -2 (special tokens)
+def open_file_with_lang(filename, mappings_dict):
+    """Same as above, but add the lcf toekns at the start of the sequence."""
+
+    max_len = 254  # 512 (max bert len) / 2 (2 sents) -2 (special tokens)
     lines = []
-    SEP_token = '[SEP]'
-    
-    langs = {'deu':'German', 
-            'eng':'English',
-            'eus': 'Basque',
-            'fas':'Farsi',
-            'fra':'French', 
-            'ita':'Italian', 
-            'nld':'Dutch',
-            'por':'Portuguese', 
-            'rus': 'Russian', 
-            'spa': 'Spanish', 
-            'tur': 'Turkish',
-            'tha': 'Thai', 
-            'zho': 'Chinese'
-            }
-
-    with open(filename, 'r', encoding='utf-8') as f:
+    SEP_token = "[SEP]"
+
+    langs = {
+        "deu": "German",
+        "eng": "English",
+        "eus": "Basque",
+        "fas": "Farsi",
+        "fra": "French",
+        "ita": "Italian",
+        "nld": "Dutch",
+        "por": "Portuguese",
+        "rus": "Russian",
+        "spa": "Spanish",
+        "tur": "Turkish",
+        "tha": "Thai",
+        "zho": "Chinese",
+    }
+
+    with open(filename, "r", encoding="utf-8") as f:
         next(f)
-        
-        lang = langs[filename.split('/')[-2].split('.')[0]]
-        framework = filename.split('/')[-2].split('.')[1]
-        fullname = filename.split('/')[-2]
-        
+
+        lang = langs[filename.split("/")[-2].split(".")[0]]
+        framework = filename.split("/")[-2].split(".")[1]
+        fullname = filename.split("/")[-2]
+
         for line in f:
-            l = line.strip().split('\t')
-            
+            l = line.strip().split("\t")
+
             if len(l) > 1:
                 # chop the sentences to max_len if too long
-                sent_1 = l[3].split(' ')
-                sent_2 = l[4].split(' ')      
-                
+                sent_1 = l[3].split(" ")
+                sent_2 = l[4].split(" ")
+
                 if len(sent_1) > max_len:
                     sent_1 = sent_1[:max_len]
                 if len(sent_2) > max_len:
                     sent_2 = sent_2[:max_len]
-                
+
                 # flip them if different direction
-                if args.normalize_direction == 'yes':
-                    if l[9] == '1>2':
-                        #lang, fullname, framework
-                        lines.append(l + [[lang, fullname, framework] + sent_1 + [SEP_token] + sent_2, encode_label(l[11], mappings_dict)])
+                if args.normalize_direction == "yes":
+                    if l[9] == "1>2":
+                        # lang, fullname, framework
+                        lines.append(
+                            l
+                            + [
+                                [lang, fullname, framework]
+                                + sent_1
+                                + [SEP_token]
+                                + sent_2,
+                                encode_label(l[11], mappings_dict),
+                            ]
+                        )
                     else:
-                        lines.append(l + [[lang, fullname, framework] + sent_2 + [SEP_token] + sent_1, encode_label(l[11], mappings_dict)])
+                        lines.append(
+                            l
+                            + [
+                                [lang, fullname, framework]
+                                + sent_2
+                                + [SEP_token]
+                                + sent_1,
+                                encode_label(l[11], mappings_dict),
+                            ]
+                        )
                 else:
-                    lines.append(l + [[lang, fullname, framework] + sent_1 + [SEP_token] + sent_2, encode_label(l[11], mappings_dict)])
+                    lines.append(
+                        l
+                        + [
+                            [lang, fullname, framework] + sent_1 + [SEP_token] + sent_2,
+                            encode_label(l[11], mappings_dict),
+                        ]
+                    )
 
     return lines
 
@@ -149,325 +195,322 @@ def open_file_with_lang(filename, mappings_dict):
 # OPENING FILES FUNCTIONS
 # ===============
 
+
 def open_sentences(path_to_corpora, mappings_dict):
-    ''' Opens all the corpora and the surprise corpora in train/dev/test sets.
-        Uses the open_file() function from utils.
-        Returns:
-        - list of sentences for TRAIN: all the corpora and surprise corpora together
-        - dict of sentences for DEV: each dev set categorized per corpus
-        - dict of sentences for TEST: each test set categorized per corpus
-        - ** NEW ** : dict of labels per framework
-    '''
+    """Opens all the corpora and the surprise corpora in train/dev/test sets.
+    Uses the open_file() function from utils.
+    Returns:
+    - list of sentences for TRAIN: all the corpora and surprise corpora together
+    - dict of sentences for DEV: each dev set categorized per corpus
+    - dict of sentences for TEST: each test set categorized per corpus
+    - ** NEW ** : dict of labels per framework
+    """
     langs_to_use = False
-    
-    if args.langs_to_use != '@':
-        langs_to_use = args.langs_to_use.split(';')
-    
-    corpora = [folder for folder in os.listdir(path_to_corpora) 
-               if not any(i in folder for i in ['.md', 'DS_', 'utils', 'ipynb'])
-               ]
-               
+
+    if args.langs_to_use != "@":
+        langs_to_use = args.langs_to_use.split(";")
+
+    corpora = [
+        folder
+        for folder in os.listdir(path_to_corpora)
+        if not any(i in folder for i in [".md", "DS_", "utils", "ipynb"])
+    ]
+
     # ---------------------
-    train_sentences     = []
-    dev_dict_sentences  = {}
+    train_sentences = []
+    dev_dict_sentences = {}
     test_dict_sentences = {}
-    
+
     all_labels = {}
 
     for corpus in corpora:
-        framework = corpus.split('.')[-2]
+        framework = corpus.split(".")[-2]
         if not framework in all_labels:
             all_labels[framework] = []
-        
+
         # ===== open train ====
         try:
-            # open normal files   
-            
-            if langs_to_use: 
-            # if we only train with cetrain corpora, we only load them
-                train_file = ['/'.join([path_to_corpora, corpus, x])
-                                  for x in os.listdir(path_to_corpora + '/' + corpus) 
-                                  if 'train' in x and 'rels' in x
-                                  if any(l in x for l in langs_to_use)
-                             ][0]
+            # open normal files
+
+            if langs_to_use:
+                # if we only train with cetrain corpora, we only load them
+                train_file = [
+                    "/".join([path_to_corpora, corpus, x])
+                    for x in os.listdir(path_to_corpora + "/" + corpus)
+                    if "train" in x and "rels" in x
+                    if any(l in x for l in langs_to_use)
+                ][0]
             else:
-                train_file = train_file = os.path.join(args.data_path, corpus, corpus + '_train.rels')
+                train_file = train_file = os.path.join(
+                    args.data_path, corpus, corpus + "_train.rels"
+                )
             temp = open_file(train_file, mappings_dict)
             train_sentences += temp
             all_labels[framework] += [l[-1] for l in temp]
 
-        except: # some of them don't have train
+        except:  # some of them don't have train
             pass
-        
+
         # ======== open dev ========
         dev_dict_sentences[corpus] = []
-        dev_file = os.path.join(args.data_path, corpus, corpus + '_dev.rels')
+        dev_file = os.path.join(args.data_path, corpus, corpus + "_dev.rels")
         temp = open_file(dev_file, mappings_dict)
         dev_dict_sentences[corpus] += temp
         all_labels[framework] += [l[-1] for l in temp]
 
         # ======== open test ========
         test_dict_sentences[corpus] = []
-        test_file = os.path.join(args.data_path, corpus, corpus + '_test.rels')
+        test_file = os.path.join(args.data_path, corpus, corpus + "_test.rels")
         temp = open_file(test_file, mappings_dict)
         test_dict_sentences[corpus] += temp
-        all_labels[framework] += [l[-1] for l in temp]  
+        all_labels[framework] += [l[-1] for l in temp]
 
-    corpus_labels = {framework:set(all_labels[framework]) for framework in all_labels}
+    corpus_labels = {framework: set(all_labels[framework]) for framework in all_labels}
     # delete unk as a sanity check
     for framework in corpus_labels:
-        if 'unk' in corpus_labels[framework]:
-            corpus_labels[framework].remove('unk')
+        if "unk" in corpus_labels[framework]:
+            corpus_labels[framework].remove("unk")
 
     return train_sentences, dev_dict_sentences, test_dict_sentences, corpus_labels
 
 
 def open_sentences_with_lang(path_to_corpora, mappings_dict):
-    ''' Opens all the corpora and the surprise corpora in train/dev/test sets.
-        Uses the open_file() function from utils.
-        Returns:
-        - list of sentences for TRAIN: all the corpora and surprise corpora together
-        - dict of sentences for DEV: each dev set categorized per corpus
-        - dict of sentences for TEST: each test set categorized per corpus
-    '''
+    """Opens all the corpora and the surprise corpora in train/dev/test sets.
+    Uses the open_file() function from utils.
+    Returns:
+    - list of sentences for TRAIN: all the corpora and surprise corpora together
+    - dict of sentences for DEV: each dev set categorized per corpus
+    - dict of sentences for TEST: each test set categorized per corpus
+    """
     langs_to_use = False
-        
-    if args.langs_to_use != '@':
-        langs_to_use = args.langs_to_use.split(';')
-    
-    corpora = [folder for folder in os.listdir(path_to_corpora) 
-               if not any(i in folder for i in ['.md', 'DS_', 'utils', 'ipynb'])
-               ]
-               
+
+    if args.langs_to_use != "@":
+        langs_to_use = args.langs_to_use.split(";")
+
+    corpora = [
+        folder
+        for folder in os.listdir(path_to_corpora)
+        if not any(i in folder for i in [".md", "DS_", "utils", "ipynb"])
+    ]
+
     # ---------------------
-    train_sentences     = []
-    dev_dict_sentences  = {}
+    train_sentences = []
+    dev_dict_sentences = {}
     test_dict_sentences = {}
-    
+
     all_labels = {}
 
     for corpus in corpora:
-        framework = corpus.split('.')[-2]
+        framework = corpus.split(".")[-2]
         if not framework in all_labels:
             all_labels[framework] = []
-        
+
         # ===== open train ====
         try:
-            # open normal files   
-            if langs_to_use: 
-            # if we only train with cetrain corpora, we only load them
-                train_file = ['/'.join([path_to_corpora, corpus, x])
-                                  for x in os.listdir(path_to_corpora + '/' + corpus) 
-                                  if 'train' in x and 'rels' in x
-                                  if any(l in x for l in langs_to_use)
-                             ][0]
+            # open normal files
+            if langs_to_use:
+                # if we only train with cetrain corpora, we only load them
+                train_file = [
+                    "/".join([path_to_corpora, corpus, x])
+                    for x in os.listdir(path_to_corpora + "/" + corpus)
+                    if "train" in x and "rels" in x
+                    if any(l in x for l in langs_to_use)
+                ][0]
             else:
-                train_file = ['/'.join([path_to_corpora, corpus, x])
-                                  for x in os.listdir(path_to_corpora + '/' + corpus) 
-                                  if 'train' in x and 'rels' in x
-                             ][0]
+                train_file = [
+                    "/".join([path_to_corpora, corpus, x])
+                    for x in os.listdir(path_to_corpora + "/" + corpus)
+                    if "train" in x and "rels" in x
+                ][0]
             temp = open_file_with_lang(train_file, mappings_dict)
             train_sentences += temp
             all_labels[framework] += [l[-1] for l in temp]
-        except: # some of them don't have train
+        except:  # some of them don't have train
             pass
 
-        #open each test separately
+        # open each test separately
         dev_dict_sentences[corpus] = []
-        dev_file = ['/'.join([path_to_corpora,corpus,x])
-                              for x in os.listdir(path_to_corpora + '/' + corpus) 
-                              if 'dev' in x and 'rels' in x][0] 
+        dev_file = [
+            "/".join([path_to_corpora, corpus, x])
+            for x in os.listdir(path_to_corpora + "/" + corpus)
+            if "dev" in x and "rels" in x
+        ][0]
         temp = open_file_with_lang(dev_file, mappings_dict)
         dev_dict_sentences[corpus] += temp
-        all_labels[framework] += [l[-1] for l in temp]  
+        all_labels[framework] += [l[-1] for l in temp]
 
-        #open each test separately
+        # open each test separately
         test_dict_sentences[corpus] = []
-        test_file = ['/'.join([path_to_corpora,corpus,x])
-                              for x in os.listdir(path_to_corpora + '/' + corpus) 
-                              if 'test' in x and 'rels' in x][0] 
+        test_file = [
+            "/".join([path_to_corpora, corpus, x])
+            for x in os.listdir(path_to_corpora + "/" + corpus)
+            if "test" in x and "rels" in x
+        ][0]
         temp = open_file_with_lang(test_file, mappings_dict)
         test_dict_sentences[corpus] += temp
-        all_labels[framework] += [l[-1] for l in temp]  
+        all_labels[framework] += [l[-1] for l in temp]
 
-    corpus_labels = {framework:set(all_labels[framework]) for framework in all_labels}
+    corpus_labels = {framework: set(all_labels[framework]) for framework in all_labels}
     # delete unk as a sanity check
     for framework in corpus_labels:
-        if 'unk' in corpus_labels[framework]:
-            corpus_labels[framework].remove('unk')
-    
-    return train_sentences, dev_dict_sentences, test_dict_sentences, corpus_labels
+        if "unk" in corpus_labels[framework]:
+            corpus_labels[framework].remove("unk")
 
+    return train_sentences, dev_dict_sentences, test_dict_sentences, corpus_labels
 
 
 # ===============
 # Testing functions
 # ===============
 
-def get_predictions(model,
-                    corpus, 
-                    test_dataloader, 
-                    print_results=True):
-    
-    ''' Function to get the model's predictions for one corpus' test set.
-        Can print accuracy using scikit-learn.
-        Also works with dev sets -- just don't save the outputs.
-        Returns: list of predictions that match test file's lines.
-    '''
-    
+
+def get_predictions(model, corpus, test_dataloader, print_results=True):
+    """Function to get the model's predictions for one corpus' test set.
+    Can print accuracy using scikit-learn.
+    Also works with dev sets -- just don't save the outputs.
+    Returns: list of predictions that match test file's lines.
+    """
+
     device = torch.device("cuda" if args.use_cuda else "cpu")
 
     if args.use_cuda:
         model = model.cuda()
-    
+
     model.eval()
     test_loss, test_accuracy = 0, 0
 
     all_labels = []
     all_preds = []
-    
+
     with torch.no_grad():
         for test_input, test_label in test_dataloader:
-
-            mask = test_input['attention_mask'].to(device)
-            input_id = test_input['input_ids'].squeeze(1).to(device)
+            mask = test_input["attention_mask"].to(device)
+            input_id = test_input["input_ids"].squeeze(1).to(device)
             output = model(input_id, mask)
 
             logits = output[0]
             logits = logits.detach().cpu().numpy()
-            label_ids = test_label.to('cpu').numpy()
+            label_ids = test_label.to("cpu").numpy()
 
             all_labels += label_ids.tolist()
             all_preds += output.argmax(dim=1).tolist()
 
         assert len(all_labels) == len(all_preds)
         test_acc = round(accuracy_score(all_labels, all_preds), 4)
-    
+
     if print_results:
-        print(corpus, '\tAccuracy:\t', test_acc)
-    
+        print(corpus, "\tAccuracy:\t", test_acc)
+
     return all_preds
-    
-
-def make_confusion_matrices(y_test,
-                            y_pred, 
-                            corpus_name,
-                            inv_mappings,
-                            epoch):
-    
-    save_path = 'conf_matrix/' + dt_string
+
+
+def make_confusion_matrices(y_test, y_pred, corpus_name, inv_mappings, epoch):
+    save_path = "conf_matrix/" + dt_string
     if not os.path.exists(save_path):
         os.makedirs(save_path)
-        
-    print(classification_report(y_test,
-                                y_pred, 
-                               )
-         )
-    
-
-    cm = confusion_matrix(y_test,
-                          y_pred, 
-                          labels = list(inv_mappings.keys())
-                         )
+
+    print(
+        classification_report(
+            y_test,
+            y_pred,
+        )
+    )
+
+    cm = confusion_matrix(y_test, y_pred, labels=list(inv_mappings.keys()))
     print(cm)
-    
+
     xticklabels = list(inv_mappings.values())
-    yticklabels = list(inv_mappings.values())    
-    
+    yticklabels = list(inv_mappings.values())
+
     sns.color_palette("cubehelix", as_cmap=True)
     # Plot the confusion matrix.
-    
+
     fig, ax = plt.subplots()
-#     ax.tick_params(axis='both', which='major', labelsize=6)
-#     ax.tick_params(axis='both', which='minor', labelsize=6)
-    ax = sns.heatmap(cm,
-                #annot=Truex
-                xticklabels=xticklabels, 
-                yticklabels=yticklabels
-               )
-    plt.ylabel('Predicted label')
-    plt.xlabel('Corpus label')
+    #     ax.tick_params(axis='both', which='major', labelsize=6)
+    #     ax.tick_params(axis='both', which='minor', labelsize=6)
+    ax = sns.heatmap(
+        cm,
+        # annot=Truex
+        xticklabels=xticklabels,
+        yticklabels=yticklabels,
+    )
+    plt.ylabel("Predicted label")
+    plt.xlabel("Corpus label")
     plt.xticks(fontsize=2)
     plt.yticks(fontsize=2)
-#     plt.xticks(x, labels, rotation='vertical')
-#     plt.margins(0.5)
+    #     plt.xticks(x, labels, rotation='vertical')
+    #     plt.margins(0.5)
     plt.subplots_adjust(bottom=0.5, left=0.5)
-    plt.title('Confusion Matrix: '+corpus_name+' (epoch:'+ str(epoch) + ')')
-    plt.savefig(save_path + '/' + corpus_name + '_' + str(epoch) + '.png', 
-                dpi=300)
+    plt.title("Confusion Matrix: " + corpus_name + " (epoch:" + str(epoch) + ")")
+    plt.savefig(save_path + "/" + corpus_name + "_" + str(epoch) + ".png", dpi=300)
     plt.clf()
 
 
-def get_better_predictions(model,
-                            corpus, 
-                            test_dataloader, 
-                            corpus_labels,
-                            inv_mappings,
-                            epoch,
-                            print_results=True, 
-                            save_conf_matrix=False):
-    
+def get_better_predictions(
+    model,
+    corpus,
+    test_dataloader,
+    corpus_labels,
+    inv_mappings,
+    epoch,
+    print_results=True,
+    save_conf_matrix=False,
+):
     device = torch.device("cuda" if args.use_cuda else "cpu")
 
     if args.use_cuda:
         model = model.cuda()
-    
+
     model.eval()
     all_labels = []
     all_preds = []
-    
+
     with torch.no_grad():
         for test_input, test_label in test_dataloader:
-
-            mask = test_input['attention_mask'].to(device)
-            input_id = test_input['input_ids'].squeeze(1).to(device)
+            mask = test_input["attention_mask"].to(device)
+            input_id = test_input["input_ids"].squeeze(1).to(device)
             output = model(input_id, mask)
 
             logits = output[0]
             logits = logits.detach().cpu().numpy()
-            label_ids = test_label.to('cpu').numpy()
+            label_ids = test_label.to("cpu").numpy()
 
-            #all_labels += label_ids.tolist()
+            # all_labels += label_ids.tolist()
             batch_labels = label_ids.tolist()
             batch_probs = []
             for p in output.softmax(dim=-1).tolist():
                 batch_probs.append(dict(enumerate(p)))
-                
+
             for probs in batch_probs:
                 final_probs = {}
-                sorted_probs = dict(sorted(probs.items(), key=lambda item:item[1]))
+                sorted_probs = dict(sorted(probs.items(), key=lambda item: item[1]))
                 for pred_label in sorted_probs:
                     if pred_label in corpus_labels:
                         final_probs[pred_label] = sorted_probs[pred_label]
-                        
+
                 all_preds += [final_probs]
-            
+
             all_labels += batch_labels
-            
+
     # get the top predictions in order to get the acc
-    
+
     top_preds = []
     for probs in all_preds:
         top_preds.append(max(zip(probs.values(), probs.keys()))[1])
     test_acc = round(accuracy_score(all_labels, top_preds), 4)
-    
+
     if print_results:
-        print('After label filtering:\t' + str(test_acc), flush='True')
-        
+        print("After label filtering:\t" + str(test_acc), flush="True")
+
         print(classification_report(all_labels, top_preds))
-    
+
     if save_conf_matrix:
         try:
-            make_confusion_matrices(all_labels, 
-                                        top_preds, 
-                                        corpus,
-                                        inv_mappings,
-                                        epoch)
+            make_confusion_matrices(all_labels, top_preds, corpus, inv_mappings, epoch)
         except ValueError:
-            print('matrix failed to print')
+            print("matrix failed to print")
 
     print()
-    print('----')
-    
-    return all_labels, all_preds
+    print("----")
 
+    return all_labels, all_preds
-- 
GitLab