Skip to content
Snippets Groups Projects
Commit c1a7f471 authored by emetheni's avatar emetheni
Browse files

correct

dictionaries
parent f3321b82
Branches
No related tags found
No related merge requests found
......@@ -28,11 +28,13 @@ adapter_name = args.adapter_name
mappings, inv_mappings = open_mappings(args.mappings_file)
substitutions_file = 'mappings/substitutions.txt'
tokenizer = AutoTokenizer.from_pretrained(args.transformer_model)
# we are saving the test results of specific epochs
specific_results = open_specific_results('mappings/specific_results.txt')
if '1-2-3' in adapter_name:
specific_results = specific_results['A1_3']
if '1-2-3' in adapter_name or 'layer1;layer2;layer3' in adapter_name:
specific_results = specific_results['A1_3'][4]
else:
specific_results = specific_results['A1']
specific_results = specific_results['A1'][3]
set_seed(42)
......@@ -147,6 +149,6 @@ for corpus in encoded_test_dataset:
test_results = get_predictions_huggingface(trainer, corpus,
encoded_test_dataset[corpus])
if corpus in specific_results[args.num_epochs]:
if corpus in specific_results:
print_results_to_file(corpus, test_dict_sentences[corpus], test_results,
inv_mappings, substitutions_file)
\ No newline at end of file
......@@ -230,10 +230,13 @@ def print_results_to_file(corpus, test_sentences, test_results,
for n, sent in enumerate(test_sentences):
label = test_results[n]
label = inv_mappings_dict[label]
if corpus in revert_substitutions:
if label in revert_substitutions[corpus]:
label = revert_substitutions[corpus][label]
temp = sent[:-3] + [label]
try:
if corpus in revert_substitutions:
if label in revert_substitutions[corpus]:
label = revert_substitutions[corpus][label]
except:
pass
temp = sent[:-2] + [label]
assert len(temp) == 12
results_to_write.append(temp)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment