Skip to content
Snippets Groups Projects
Commit c1a7f471 authored by emetheni's avatar emetheni
Browse files

correct

dictionaries
parent f3321b82
Branches
No related tags found
No related merge requests found
...@@ -28,11 +28,13 @@ adapter_name = args.adapter_name ...@@ -28,11 +28,13 @@ adapter_name = args.adapter_name
mappings, inv_mappings = open_mappings(args.mappings_file) mappings, inv_mappings = open_mappings(args.mappings_file)
substitutions_file = 'mappings/substitutions.txt' substitutions_file = 'mappings/substitutions.txt'
tokenizer = AutoTokenizer.from_pretrained(args.transformer_model) tokenizer = AutoTokenizer.from_pretrained(args.transformer_model)
# we are saving the test results of specific epochs
specific_results = open_specific_results('mappings/specific_results.txt') specific_results = open_specific_results('mappings/specific_results.txt')
if '1-2-3' in adapter_name: if '1-2-3' in adapter_name or 'layer1;layer2;layer3' in adapter_name:
specific_results = specific_results['A1_3'] specific_results = specific_results['A1_3'][4]
else: else:
specific_results = specific_results['A1'] specific_results = specific_results['A1'][3]
set_seed(42) set_seed(42)
...@@ -147,6 +149,6 @@ for corpus in encoded_test_dataset: ...@@ -147,6 +149,6 @@ for corpus in encoded_test_dataset:
test_results = get_predictions_huggingface(trainer, corpus, test_results = get_predictions_huggingface(trainer, corpus,
encoded_test_dataset[corpus]) encoded_test_dataset[corpus])
if corpus in specific_results[args.num_epochs]: if corpus in specific_results:
print_results_to_file(corpus, test_dict_sentences[corpus], test_results, print_results_to_file(corpus, test_dict_sentences[corpus], test_results,
inv_mappings, substitutions_file) inv_mappings, substitutions_file)
\ No newline at end of file
...@@ -230,10 +230,13 @@ def print_results_to_file(corpus, test_sentences, test_results, ...@@ -230,10 +230,13 @@ def print_results_to_file(corpus, test_sentences, test_results,
for n, sent in enumerate(test_sentences): for n, sent in enumerate(test_sentences):
label = test_results[n] label = test_results[n]
label = inv_mappings_dict[label] label = inv_mappings_dict[label]
if corpus in revert_substitutions: try:
if label in revert_substitutions[corpus]: if corpus in revert_substitutions:
label = revert_substitutions[corpus][label] if label in revert_substitutions[corpus]:
temp = sent[:-3] + [label] label = revert_substitutions[corpus][label]
except:
pass
temp = sent[:-2] + [label]
assert len(temp) == 12 assert len(temp) == 12
results_to_write.append(temp) results_to_write.append(temp)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment