diff --git a/classifier_bare_huggingface.py b/classifier_bare_huggingface.py
index 5b60babafe421c7236832bb2ae021509fd188ec2..4824ebccb8ee91f3a6820f0a4dc74c219bb0e664 100644
--- a/classifier_bare_huggingface.py
+++ b/classifier_bare_huggingface.py
@@ -152,8 +152,11 @@ trainer.train()
 print('\nDev results:')
 for corpus in encoded_dev_dataset:
     print()
-    dev_results = get_predictions_huggingface(trainer, corpus,
-                                               encoded_dev_dataset[corpus])
+    dev_results = better_predictions_huggingface(trainer,
+                                                 corpus,
+                                                 encoded_dev_dataset[corpus],
+                                                 framework_labels[corpus.split('.')[1]]
+                                                 )
     print(dev_results)
 
 
@@ -171,17 +174,18 @@ for corpus in encoded_dev_dataset:
 
 
 # Test results
-print('\ntest results:')
-for corpus in encoded_test_dataset:
-    print()
-    test_results = get_predictions_huggingface(trainer,
-                                                corpus,
-                                                encoded_test_dataset[corpus])
+# print('\ntest results:')
+# for corpus in encoded_test_dataset:
+#     print()
+#     test_results = get_predictions_huggingface(trainer,
+#                                                 corpus,
+#                                                 framework_labels[corpus.split('.')[1]],
+#                                                 encoded_test_dataset[corpus])
 
-    path_results = 'results/test/' + args.transformer_model + '_' + str(args.num_epochs)
-    if not os.path.exists(path_results):
-        os.makedirs(path_results)
+#     path_results = 'results/test/' + args.transformer_model + '_' + str(args.num_epochs)
+#     if not os.path.exists(path_results):
+#         os.makedirs(path_results)
 
 # print_results_to_file(corpus,
 #                       test_dict_sentences[corpus],
diff --git a/make_mappings_zero-shot.py b/make_mappings_zero-shot.py
index cacad68a8ccd94ee3f8c6622a00259cbd69c3d94..9f57d2ce825d5ccb0605c6b9c45ef6aac9ca2c81 100644
--- a/make_mappings_zero-shot.py
+++ b/make_mappings_zero-shot.py
@@ -72,7 +72,7 @@ for f in rel_files:
 
 dict_labels = dict(enumerate(list(set(good_rels))))
 corpora_labels = {v:k for k, v in dict_labels.items()}
-corpora_labels['unk'] = len(corpora_labels)
+
 
 leftovers = []
 
@@ -90,9 +90,11 @@
     try:
         corpora_labels[sub] = corpora_labels[subs[sub]]
     except KeyError:
-        corpora_labels[subs[sub]] = max(list(corpora_labels.values())) + 1
+        corpora_labels[subs[sub]] = max(list(corpora_labels.values())) + 1
         corpora_labels[sub] = corpora_labels[subs[sub]]
 
+corpora_labels['unk'] = max(list(corpora_labels.values())) + 1
+
 # print(corpora_labels)
 
 with open(args.mappings_file, 'w') as f:
diff --git a/mappings/mappings_final.tsv b/mappings/mappings_final.tsv
new file mode 100644
index 0000000000000000000000000000000000000000..81ff02718d9a452dd2e0d3e0e77a5c84bc7329b7
--- /dev/null
+++ b/mappings/mappings_final.tsv
@@ -0,0 +1,158 @@
+LABEL MAPPING
+result 0
+organization-phatic 1
+expansion.correction 2
+expansion.equivalence 3
+volitional-cause 4
+causal-result 5
+preparation 6
+summary 7
+elaboration-additional 8
+attribution 9
+temporal.asynchronous 10
+reason 11
+enablement 12
+restatement-repetition 13
+nonvolitional-cause-e 14
+contingency.purpose 15
+solutionhood 16
+topic-solutionhood 17
+temporal 18
+expansion.instantiation 19
+progression 20
+purpose-goal 21
+comment 22
+narration 23
+evaluation-comment 24
+conclusion 25
+explanation* 26
+contingency.cause 27
+temporal.synchronous 28
+comparison.concession 29
+contingency.cause+speechact 30
+continuation 31
+organization-heading 32
+acknowledgement 33
+circumstance 34
+nonvolitional-result 35
+contingency.condition+speechact 36
+temploc 37
+elab-addition 38
+elab-enumember 39
+effect 40
+restatement-mn 41
+condition 42
+topic-comment 43
+nonvolitional-cause 44
+joint-list 45
+concession 46
+clarification_question 47
+antithesis 48
+attribution-negative 49
+parenthetical 50
+purpose-attribute 51
+expansion.manner 52
+explanation-motivation 53
+q_elab 54
+context-background 55
+contingency.negative-cause 56
+conjunction 57
+frame 58
+contingency.negative 59
+bg-goal 60
+expansion.substitution 61
+nonvolitional-result-e 62
+contingency.condition 63
+volitional-result 64
+hypophora 65
+comparison.contrast 66
+adversative-antithesis 67
+joint-other 68
+mode-manner 69
+contingency-condition 70
+elab-example 71
+contingency.negative-condition+speechact 72
+joint-sequence 73
+expansion.level-of-detail 74
+explanation-evidence 75
+restatement-partial 76
+elab-aspect 77
+expansion.exception 78
+exp-reason 79
+evaluation-n 80
+topic-question 81
+parallel 82
+contingency.negative-condition 83
+evaluation 84
+attribution-positive 85
+topic-change 86
+bg-compare 87
+joint 88
+expansion.disjunction 89
+interpretation-evaluation 90
+elab-definition 91
+context-circumstance 92
+adversative-contrast 93
+causal-cause 94
+topic-drift 95
+elaboration 96
+expansion.restatement 97
+contingency.goal 98
+manner-means 99
+background 100
+mode-means 101
+comparison.similarity 102
+means 103
+comparison 104
+flashback 105
+interpretation 106
+explanation-justify 107
+organization-preparation 108
+elaboration-attribute 109
+contrast 110
+adversative-concession 111
+comparison.degree 112
+root 113
+e-elaboration 114
+expansion.alternative 115
+evaluation-s 116
+comparison.concession+speechact 117
+contingency.cause+belief 118
+purpose 119
+interrupted 120
+expansion 121
+cause 122
+question_answer_pair 123
+elab-process_step 124
+explanation 125
+cause-effect 126
+unk 127
+list 45
+evidence 75
+sequence 73
+disjunction 89
+justify 107
+motivation 53
+restatement 97
+expansion.genexpansion 121
+expansion.conjunction 57
+joint-disjunction 89
+goal 21
+alternation 115
+conditional 42
+adversative 128
+otherwise 128
+correction 2
+unconditional 89
+unless 110
+bg-general 100
+exp-evidence 75
+organization 129
+textual-organization 129
+alternative 115
+temporal.synchrony 28
+repetition 13
+expansion.level 74
+qap.hypophora 65
+qap 123
+causation 122
diff --git a/utils.py b/utils.py
index 0e13b338804381cadccc829c381745955127c9c0..e3b17afaa7222453307a17f7249bcbdfe62f2228 100644
--- a/utils.py
+++ b/utils.py
@@ -142,11 +142,11 @@ def open_file_with_lang(filename, mappings_dict):
 
         # flip them if different direction
         if args.normalize_direction == 'yes':
             if l[9] == '1>2':
-                lines.append(l + [[lang, fullname, framework] + sent_1 + [SEP_token] + sent_2, encode_label(l[11], mappings_dict)])
+                lines.append(l + [[lang, fullname] + sent_1 + [SEP_token] + sent_2, encode_label(l[11], mappings_dict)])
             else:
-                lines.append(l + [[lang, fullname, framework] + sent_2 + [SEP_token] + sent_1, encode_label(l[11], mappings_dict)])
+                lines.append(l + [[lang, fullname] + sent_2 + [SEP_token] + sent_1, encode_label(l[11], mappings_dict)])
         else:
-            lines.append(l + [[lang, fullname, framework] + sent_1 + [SEP_token] + sent_2, encode_label(l[11], mappings_dict)])
+            lines.append(l + [[lang, fullname] + sent_1 + [SEP_token] + sent_2, encode_label(l[11], mappings_dict)])
 
     return lines
 
@@ -391,6 +391,45 @@ def get_predictions_huggingface(trainer,
 
     return preds
 
+
+def better_predictions_huggingface(trainer,
+                                   corpus,
+                                   test_set,
+                                   corpus_labels,
+                                   print_results=True):
+
+    ''' SPECIFIC FUNCTION FOR THE HUGGINGFACE TRAINER.
+    Function to get the model's predictions for one corpus' test set.
+    Can print accuracy using scikit-learn.
+    Also works with dev sets -- just don't save the outputs.
+    Returns: list of predictions that match test file's lines.
+    '''
+
+    results = trainer.predict(test_set)
+    preds = np.argmax(results.predictions, axis=1)
+    orig_labels = results.label_ids
+    # accuracy with the unconstrained argmax over all labels
+    test_acc = round(accuracy_score(preds, orig_labels), 4)
+
+    if print_results:
+        print(corpus + '\t' + str(test_acc) + '\n', flush=True)
+
+    # restrict each prediction to the labels of the corpus' framework
+    best_labels = []
+    for result in results.predictions.tolist():
+        best_prob = float('-inf')
+        best_label = -1
+        for label in corpus_labels:
+            if result[label] > best_prob:
+                best_prob = result[label]
+                best_label = label
+        best_labels.append(best_label)
+
+    test_acc = round(accuracy_score(best_labels, orig_labels), 4)
+    if print_results:
+        print('better:\t' + str(test_acc) + '\n', flush=True)
+
+    return best_labels
+
 
 
 def get_better_predictions(model,
                            corpus,
@@ -443,7 +482,7 @@ def get_better_predictions(model,
 
 
     test_acc = round(accuracy_score(all_labels, top_preds), 4)
     if print_results:
-        print('better:', '\t', test_acc, flush='')
+        print('better:\t' + str(test_acc), flush=True)
 
     return all_labels, all_preds
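
A minimal sketch (not part of the patch) of how the framework_labels dict passed to better_predictions_huggingface() above could be built from mappings/mappings_final.tsv. The helper name build_framework_labels() and its framework_to_relnames argument (framework name -> relation label strings, e.g. collected from the per-framework .rels files) are assumptions for illustration only.

def build_framework_labels(mappings_path, framework_to_relnames):
    # label string -> integer id, as written out by make_mappings_zero-shot.py
    label_to_id = {}
    with open(mappings_path, encoding='utf-8') as f:
        next(f)                           # skip the LABEL / MAPPING header
        for line in f:
            label, idx = line.split()     # labels contain no whitespace
            label_to_id[label] = int(idx)

    # framework short name (e.g. 'rst' from 'eng.rst.gum') -> sorted label ids
    return {fw: sorted({label_to_id[r] for r in rels if r in label_to_id})
            for fw, rels in framework_to_relnames.items()}

# framework_labels = build_framework_labels('mappings/mappings_final.tsv',
#                                            framework_to_relnames)
# dev_results = better_predictions_huggingface(trainer, corpus,
#                                              encoded_dev_dataset[corpus],
#                                              framework_labels[corpus.split('.')[1]])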