Skip to content
Snippets Groups Projects
Commit f3321b82 authored by emetheni
Browse files

correct generators

parent b0c32727
No related branches found
No related tags found
No related merge requests found
......@@ -50,19 +50,27 @@ train_sentences, dev_dict_sentences, test_dict_sentences = open_sentences(args.d
# make pandas dataframes
file_header = ['text', 'labels']
train_df = pd.DataFrame([[' '.join(x[-2]), x[-1]] for x in train_sentences], columns =file_header)
train_df = pd.DataFrame([[' '.join(x[-2]), x[-1]] for x in train_sentences],
columns =file_header)
train_df = train_df.sample(frac = 1) # shuffle the train
# get a global dev accuracy, we will not be directly using these results
dev_dict_df = {corpus:pd.DataFrame([[' '.join(x[-2]), x[-1]] for x in sents], columns = file_header)
for corpus, sents in dev_dict_sentences.items()}
test_dict_df = {corpus:pd.DataFrame([[' '.join(x[-2]), x[-1]] for x in sents], columns = file_header)
for corpus, sents in test_dict_sentences.items()}
dev_dict_df = {corpus : pd.DataFrame([[' '.join(x[-2]), x[-1]]
for x in sents],
columns = file_header)
for corpus, sents in dev_dict_sentences.items()}
test_dict_df = {corpus : pd.DataFrame([[' '.join(x[-2]), x[-1]]
for x in sents],
columns = file_header)
for corpus, sents in test_dict_sentences.items()}
#Make datasets from dataframes
train_dataset = datasets.Dataset.from_pandas(train_df)
dev_dict_dataset = {corpus:datasets.Dataset.from_pandas(dev_df) for dev_df in dev_dict_df}
test_dict_dataset = {corpus:datasets.Dataset.from_pandas(dev_df) for dev_df in dev_dict_df}
dev_dict_dataset = {corpus:datasets.Dataset.from_pandas(dev_df)
for corpus, dev_df in dev_dict_df.items()}
test_dict_dataset = {corpus:datasets.Dataset.from_pandas(dev_df)
for corpus, dev_df in test_dict_df.items()}
# get number of labels
num_labels = len(set([int(x.strip())
......
......@@ -231,7 +231,8 @@ def print_results_to_file(corpus, test_sentences, test_results,
label = test_results[n]
label = inv_mappings_dict[label]
if corpus in revert_substitutions:
label = revert_substitutions[corpus][label]
if label in revert_substitutions[corpus]:
label = revert_substitutions[corpus][label]
temp = sent[:-3] + [label]
assert len(temp) == 12
results_to_write.append(temp)
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment