diff --git a/adapter_classifier.py b/adapter_classifier.py index 0144101446e360c370bdc500df62e95561740c3b..d7ca592fafee501fc8829c932dd6a03732c923f8 100644 --- a/adapter_classifier.py +++ b/adapter_classifier.py @@ -28,11 +28,13 @@ adapter_name = args.adapter_name mappings, inv_mappings = open_mappings(args.mappings_file) substitutions_file = 'mappings/substitutions.txt' tokenizer = AutoTokenizer.from_pretrained(args.transformer_model) + +# we are saving the test results of specific epochs specific_results = open_specific_results('mappings/specific_results.txt') -if '1-2-3' in adapter_name: - specific_results = specific_results['A1_3'] +if '1-2-3' in adapter_name or 'layer1;layer2;layer3' in adapter_name: + specific_results = specific_results['A1_3'][4] else: - specific_results = specific_results['A1'] + specific_results = specific_results['A1'][3] set_seed(42) @@ -147,6 +149,6 @@ for corpus in encoded_test_dataset: test_results = get_predictions_huggingface(trainer, corpus, encoded_test_dataset[corpus]) - if corpus in specific_results[args.num_epochs]: + if corpus in specific_results: print_results_to_file(corpus, test_dict_sentences[corpus], test_results, inv_mappings, substitutions_file) \ No newline at end of file diff --git a/utils.py b/utils.py index 9e069e9930c663467f5934aef6a785557e9c6595..fe1d5c5b240896eb55f08f38ac427c167c6fdd5b 100644 --- a/utils.py +++ b/utils.py @@ -230,10 +230,13 @@ def print_results_to_file(corpus, test_sentences, test_results, for n, sent in enumerate(test_sentences): label = test_results[n] label = inv_mappings_dict[label] - if corpus in revert_substitutions: - if label in revert_substitutions[corpus]: - label = revert_substitutions[corpus][label] - temp = sent[:-3] + [label] + try: + if corpus in revert_substitutions: + if label in revert_substitutions[corpus]: + label = revert_substitutions[corpus][label] + except: + pass + temp = sent[:-2] + [label] assert len(temp) == 12 results_to_write.append(temp)