From c1a7f4718f4307067873ed0e34fde45e7445c9de Mon Sep 17 00:00:00 2001
From: emetheni <lenakmeth@gmail.com>
Date: Sat, 13 May 2023 14:54:30 +0200
Subject: [PATCH] correct dictionaries

---
 adapter_classifier.py | 10 ++++++----
 utils.py              | 11 +++++++----
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/adapter_classifier.py b/adapter_classifier.py
index 0144101..d7ca592 100644
--- a/adapter_classifier.py
+++ b/adapter_classifier.py
@@ -28,11 +28,13 @@ adapter_name = args.adapter_name
 mappings, inv_mappings = open_mappings(args.mappings_file)
 substitutions_file = 'mappings/substitutions.txt'
 tokenizer = AutoTokenizer.from_pretrained(args.transformer_model)
+
+# we are saving the test results of specific epochs
 specific_results = open_specific_results('mappings/specific_results.txt')
-if '1-2-3' in adapter_name:
-    specific_results = specific_results['A1_3']
+if '1-2-3' in adapter_name or 'layer1;layer2;layer3' in adapter_name:
+    specific_results = specific_results['A1_3'][4]
 else:
-    specific_results = specific_results['A1']
+    specific_results = specific_results['A1'][3]
 
 set_seed(42)
 
@@ -147,6 +149,6 @@ for corpus in encoded_test_dataset:
     test_results = get_predictions_huggingface(trainer, corpus, 
                                     encoded_test_dataset[corpus])
     
-    if corpus in specific_results[args.num_epochs]:
+    if corpus in specific_results:
         print_results_to_file(corpus, test_dict_sentences[corpus], test_results, 
                               inv_mappings, substitutions_file)
\ No newline at end of file
diff --git a/utils.py b/utils.py
index 9e069e9..fe1d5c5 100644
--- a/utils.py
+++ b/utils.py
@@ -230,10 +230,13 @@ def print_results_to_file(corpus, test_sentences, test_results,
     for n, sent in enumerate(test_sentences):
         label = test_results[n]
         label = inv_mappings_dict[label]
-        if corpus in revert_substitutions:
-            if label in revert_substitutions[corpus]:
-                label = revert_substitutions[corpus][label]
-        temp  = sent[:-3] + [label]
+        try:
+            if corpus in revert_substitutions:
+                if label in revert_substitutions[corpus]:
+                    label = revert_substitutions[corpus][label]
+        except:
+            pass
+        temp  = sent[:-2] + [label]
         assert len(temp) == 12
         results_to_write.append(temp)
     
-- 
GitLab