From f3321b82ab8f10c2a7e5a845941493230b6fdb82 Mon Sep 17 00:00:00 2001
From: emetheni <lenakmeth@gmail.com>
Date: Fri, 12 May 2023 23:50:51 +0200
Subject: [PATCH] correct generators

---
 adapter_classifier.py | 24 ++++++++++++++++--------
 utils.py              |  3 ++-
 2 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/adapter_classifier.py b/adapter_classifier.py
index 478ae6e..0144101 100644
--- a/adapter_classifier.py
+++ b/adapter_classifier.py
@@ -50,19 +50,27 @@ train_sentences, dev_dict_sentences, test_dict_sentences = open_sentences(args.d
 
 # make pandas dataframes
 file_header = ['text', 'labels']
-train_df = pd.DataFrame([[' '.join(x[-2]), x[-1]] for x in train_sentences], columns =file_header)
+
+train_df = pd.DataFrame([[' '.join(x[-2]), x[-1]] for x in train_sentences], 
+                        columns =file_header)
 train_df = train_df.sample(frac = 1) # shuffle the train
-# get a global dev accuracy, we will not be directly using these results
-dev_dict_df = {corpus:pd.DataFrame([[' '.join(x[-2]), x[-1]] for x in sents], columns = file_header) 
-                for corpus, sents in dev_dict_sentences.items()}
 
-test_dict_df  = {corpus:pd.DataFrame([[' '.join(x[-2]), x[-1]] for x in sents], columns = file_header) 
-                for corpus, sents in test_dict_sentences.items()}
+dev_dict_df = {corpus : pd.DataFrame([[' '.join(x[-2]), x[-1]] 
+                                      for x in sents], 
+                                     columns = file_header)
+               for corpus, sents in dev_dict_sentences.items()}
+
+test_dict_df = {corpus : pd.DataFrame([[' '.join(x[-2]), x[-1]] 
+                                      for x in sents], 
+                                     columns = file_header)
+               for corpus, sents in test_dict_sentences.items()}
 
 #Make datasets from dataframes
 train_dataset = datasets.Dataset.from_pandas(train_df)
-dev_dict_dataset  = {corpus:datasets.Dataset.from_pandas(dev_df) for dev_df in dev_dict_df}
-test_dict_dataset = {corpus:datasets.Dataset.from_pandas(dev_df) for dev_df in dev_dict_df}
+dev_dict_dataset  = {corpus:datasets.Dataset.from_pandas(dev_df) 
+                     for corpus, dev_df in dev_dict_df.items()}
+test_dict_dataset = {corpus:datasets.Dataset.from_pandas(dev_df) 
+                     for corpus, dev_df in test_dict_df.items()}
 
 # get number of labels
 num_labels = len(set([int(x.strip()) 
diff --git a/utils.py b/utils.py
index f9d610a..9e069e9 100644
--- a/utils.py
+++ b/utils.py
@@ -231,7 +231,8 @@ def print_results_to_file(corpus, test_sentences, test_results,
         label = test_results[n]
         label = inv_mappings_dict[label]
         if corpus in revert_substitutions:
-            label = revert_substitutions[corpus][label]
+            if label in revert_substitutions[corpus]:
+                label = revert_substitutions[corpus][label]
         temp  = sent[:-3] + [label]
         assert len(temp) == 12
         results_to_write.append(temp)
-- 
GitLab