Skip to content
Snippets Groups Projects
Commit f3321b82 authored by emetheni's avatar emetheni
Browse files

correct generators

parent b0c32727
Branches
No related tags found
No related merge requests found
......@@ -50,19 +50,27 @@ train_sentences, dev_dict_sentences, test_dict_sentences = open_sentences(args.d
# make pandas dataframes: one row per sentence, where 'text' is the
# space-joined second-to-last field of the sentence record and 'labels' is
# its last field
file_header = ['text', 'labels']

train_df = pd.DataFrame([[' '.join(x[-2]), x[-1]] for x in train_sentences],
                        columns=file_header)
train_df = train_df.sample(frac=1)  # shuffle the train

# get a global dev accuracy, we will not be directly using these results
# (dev and test sentences are stored per corpus, so build one dataframe
# per corpus key)
dev_dict_df = {corpus: pd.DataFrame([[' '.join(x[-2]), x[-1]]
                                     for x in sents],
                                    columns=file_header)
               for corpus, sents in dev_dict_sentences.items()}
test_dict_df = {corpus: pd.DataFrame([[' '.join(x[-2]), x[-1]]
                                      for x in sents],
                                     columns=file_header)
                for corpus, sents in test_dict_sentences.items()}

# Make datasets from dataframes
train_dataset = datasets.Dataset.from_pandas(train_df)
# Iterate over .items(): iterating the dict directly yields only the corpus
# names, which leaves `corpus` unbound and passes strings to from_pandas.
dev_dict_dataset = {corpus: datasets.Dataset.from_pandas(dev_df)
                    for corpus, dev_df in dev_dict_df.items()}
# The test datasets must come from test_dict_df, not dev_dict_df.
test_dict_dataset = {corpus: datasets.Dataset.from_pandas(test_df)
                     for corpus, test_df in test_dict_df.items()}
# get number of labels
num_labels = len(set([int(x.strip())
......
......@@ -231,7 +231,8 @@ def print_results_to_file(corpus, test_sentences, test_results,
label = test_results[n]
label = inv_mappings_dict[label]
if corpus in revert_substitutions:
label = revert_substitutions[corpus][label]
if label in revert_substitutions[corpus]:
label = revert_substitutions[corpus][label]
temp = sent[:-3] + [label]
assert len(temp) == 12
results_to_write.append(temp)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment