Skip to content
Snippets Groups Projects
Commit 9a406a06 authored by emetheni's avatar emetheni
Browse files

fix pytorch bug

parent b27f964b
Branches
No related tags found
No related merge requests found
...@@ -121,7 +121,13 @@ class TransformerClassifier(nn.Module): ...@@ -121,7 +121,13 @@ class TransformerClassifier(nn.Module):
model = TransformerClassifier() model = TransformerClassifier()
def train(model, train_dataloader, dev_dict_dataloader, test_dict_sentences, epochs, specific_results): def train(model,
train_dataloader,
dev_dict_dataloader,
test_dict_sentences,
test_dict_dataloader,
epochs,
specific_results):
device = torch.device("cuda" if args.use_cuda else "cpu") device = torch.device("cuda" if args.use_cuda else "cpu")
...@@ -153,6 +159,7 @@ def train(model, train_dataloader, dev_dict_dataloader, test_dict_sentences, epo ...@@ -153,6 +159,7 @@ def train(model, train_dataloader, dev_dict_dataloader, test_dict_sentences, epo
for epoch_num in range(0, epochs): for epoch_num in range(0, epochs):
print('\n=== Epoch {:} / {:} ==='.format(epoch_num + 1, epochs)) print('\n=== Epoch {:} / {:} ==='.format(epoch_num + 1, epochs))
model.train() model.train()
total_acc_train = 0 total_acc_train = 0
...@@ -182,16 +189,22 @@ def train(model, train_dataloader, dev_dict_dataloader, test_dict_sentences, epo ...@@ -182,16 +189,22 @@ def train(model, train_dataloader, dev_dict_dataloader, test_dict_sentences, epo
# Dev results for each corpus. We don't need to save the results. # Dev results for each corpus. We don't need to save the results.
for corpus in dev_dict_dataloader: for corpus in dev_dict_dataloader:
_ = get_predictions(model, corpus, dev_dict_dataloader[corpus]) _ = get_predictions(model,
corpus,
dev_dict_dataloader[corpus])
# we want the results of specific epochs for specific corpora. # we want the results of specific epochs for specific corpora.
# we define the epochs and the corpora and we save only these results. # we define the epochs and the corpora and we save only these results.
if epoch_num+1 in specific_results: if epoch_num+1 in specific_results:
if corpus in specific_results[epoch_num+1]: for corpus in specific_results[epoch_num+1]:
test_results = get_predictions(model, corpus, dev_dict_dataloader[corpus], test_results = get_predictions(model,
corpus,
test_dict_dataloader[corpus],
print_results=False) print_results=False)
print_results_to_file(corpus, test_dict_sentences[corpus], test_results, print_results_to_file(corpus,
test_dict_sentences[corpus],
test_results,
inv_mappings, substitutions_file) inv_mappings, substitutions_file)
...@@ -200,13 +213,26 @@ def train(model, train_dataloader, dev_dict_dataloader, test_dict_sentences, epo ...@@ -200,13 +213,26 @@ def train(model, train_dataloader, dev_dict_dataloader, test_dict_sentences, epo
print('\nModel: ', args.transformer_model) print('\nModel: ', args.transformer_model)
print('Batch size: ', args.batch_size * args.gradient_accumulation_steps) print('Batch size: ', args.batch_size * args.gradient_accumulation_steps)
print('\nStart training...\n') print('\nStart training...\n')
train(model, train_dataloader, dev_dict_dataloader, test_dict_sentences, args.num_epochs, specific_results) train(model,
train_dataloader,
dev_dict_dataloader,
test_dict_sentences,
test_dict_dataloader,
args.num_epochs,
specific_results)
print('\nTraining Done!') print('\nTraining Done!')
# ------- Testing --------- # ------- Testing ---------
print('Testing...')
for corpus in test_dict_dataloader: for corpus in test_dict_dataloader:
test_results = get_predictions(model, corpus, test_dict_dataloader[corpus]) test_results = get_predictions(model,
print_results_to_file(corpus, test_dict_sentences[corpus], test_results, corpus,
inv_mappings, substitutions_file) test_dict_dataloader[corpus]
\ No newline at end of file )
# print_results_to_file(corpus,
# test_dict_sentences[corpus],
# test_results,
# inv_mappings,
# substitutions_file)
\ No newline at end of file
...@@ -141,7 +141,10 @@ def open_sentences(path_to_corpora, mappings_dict): ...@@ -141,7 +141,10 @@ def open_sentences(path_to_corpora, mappings_dict):
# Testing functions # Testing functions
# =============== # ===============
def get_predictions(model, corpus, test_dataloader, print_results=True): def get_predictions(model,
corpus,
test_dataloader,
print_results=True):
''' Function to get the model's predictions for one corpus' test set. ''' Function to get the model's predictions for one corpus' test set.
Can print accuracy using scikit-learn. Can print accuracy using scikit-learn.
...@@ -183,7 +186,10 @@ def get_predictions(model, corpus, test_dataloader, print_results=True): ...@@ -183,7 +186,10 @@ def get_predictions(model, corpus, test_dataloader, print_results=True):
return all_preds return all_preds
def get_predictions_huggingface(trainer, corpus, test_set, print_results=True): def get_predictions_huggingface(trainer,
corpus,
test_set,
print_results=True):
''' SPECIFI FUNCTION FOR THE HUGGINGFACE TRAINER. ''' SPECIFI FUNCTION FOR THE HUGGINGFACE TRAINER.
Function to get the model's predictions for one corpus' test set. Function to get the model's predictions for one corpus' test set.
...@@ -203,8 +209,11 @@ def get_predictions_huggingface(trainer, corpus, test_set, print_results=True): ...@@ -203,8 +209,11 @@ def get_predictions_huggingface(trainer, corpus, test_set, print_results=True):
return preds return preds
def print_results_to_file(corpus, test_sentences, test_results, def print_results_to_file(corpus,
inv_mappings_dict, substitutions_file): test_sentences,
test_results,
inv_mappings_dict,
substitutions_file):
''' Function to print a new file with the test predictions per ''' Function to print a new file with the test predictions per
the specifications of the Shared task. the specifications of the Shared task.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment