From c1a7f4718f4307067873ed0e34fde45e7445c9de Mon Sep 17 00:00:00 2001 From: emetheni <lenakmeth@gmail.com> Date: Sat, 13 May 2023 14:54:30 +0200 Subject: [PATCH] correct dictionaries --- adapter_classifier.py | 10 ++++++---- utils.py | 11 +++++++---- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/adapter_classifier.py b/adapter_classifier.py index 0144101..d7ca592 100644 --- a/adapter_classifier.py +++ b/adapter_classifier.py @@ -28,11 +28,13 @@ adapter_name = args.adapter_name mappings, inv_mappings = open_mappings(args.mappings_file) substitutions_file = 'mappings/substitutions.txt' tokenizer = AutoTokenizer.from_pretrained(args.transformer_model) + +# we are saving the test results of specific epochs specific_results = open_specific_results('mappings/specific_results.txt') -if '1-2-3' in adapter_name: - specific_results = specific_results['A1_3'] +if '1-2-3' in adapter_name or 'layer1;layer2;layer3' in adapter_name: + specific_results = specific_results['A1_3'][4] else: - specific_results = specific_results['A1'] + specific_results = specific_results['A1'][3] set_seed(42) @@ -147,6 +149,6 @@ for corpus in encoded_test_dataset: test_results = get_predictions_huggingface(trainer, corpus, encoded_test_dataset[corpus]) - if corpus in specific_results[args.num_epochs]: + if corpus in specific_results: print_results_to_file(corpus, test_dict_sentences[corpus], test_results, inv_mappings, substitutions_file) \ No newline at end of file diff --git a/utils.py b/utils.py index 9e069e9..fe1d5c5 100644 --- a/utils.py +++ b/utils.py @@ -230,10 +230,13 @@ def print_results_to_file(corpus, test_sentences, test_results, for n, sent in enumerate(test_sentences): label = test_results[n] label = inv_mappings_dict[label] - if corpus in revert_substitutions: - if label in revert_substitutions[corpus]: - label = revert_substitutions[corpus][label] - temp = sent[:-3] + [label] + try: + if corpus in revert_substitutions: + if label in revert_substitutions[corpus]: + label = revert_substitutions[corpus][label] + except: + pass + temp = sent[:-2] + [label] assert len(temp) == 12 results_to_write.append(temp) -- GitLab