Commit 77420c0b authored by Julien Rabault
# Conflicts:
#	.gitignore
parents d7dd442e ebe65af1
1 merge request: !1 Draft: Master
@@ -5,17 +5,17 @@ dim_encoder = 768
 dim_decoder = 768
 num_rnn_layers=1
 dropout=0.1
-teacher_forcing=0.8
+teacher_forcing=0.05
 symbols_vocab_size=26
 max_symbols_in_sentence=1250
 max_len_sentence=112
 [MODEL_TRAINING]
 device=cpu
-batch_size=16
+batch_size=32
 epoch=20
 seed_val=42
 learning_rate=0.005
-use_checkpoint_SAVE=1
+use_checkpoint_SAVE=0
 output_path=Output
 use_checkpoint_LOAD=0
 input_path=Input
...
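Note on these [MODEL_TRAINING] values: they are plain INI strings, which is why train.py switches to getboolean further down. A minimal sketch, assuming the project's Configuration class wraps Python's configparser (the class itself is not part of this diff, and the config path below is a placeholder):

    from configparser import ConfigParser

    config = ConfigParser()
    config.read("Configuration/config.ini")  # placeholder path, not taken from this diff

    training = config["MODEL_TRAINING"]
    batch_size = training.getint("batch_size")            # 32
    learning_rate = training.getfloat("learning_rate")    # 0.005

    # Every raw INI value is a string and bool("0") is True in Python,
    # so 0/1 flags must go through getboolean(), as train.py now does.
    use_checkpoint_save = training.getboolean("use_checkpoint_SAVE")  # False
    use_checkpoint_load = training.getboolean("use_checkpoint_LOAD")  # False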
File added
File added
File added
@@ -2,7 +2,7 @@ import random
 import torch
 import torch.nn.functional as F
-from torch.nn import (Dropout, Module, Linear, LSTM)
+from torch.nn import (Dropout, Module, ModuleList, Linear, LSTM, GRU)
 from Configuration import Configuration
 from SuperTagger.Symbol.SymbolEmbedding import SymbolEmbedding
@@ -18,6 +18,8 @@ class RNNDecoderLayer(Module):
         self.symbols_vocab_size = int(Configuration.modelDecoderConfig['symbols_vocab_size'])
         dropout = float(Configuration.modelDecoderConfig['dropout'])
         self.num_rnn_layers = int(Configuration.modelDecoderConfig['num_rnn_layers'])
+        self.teacher_forcing = float(Configuration.modelDecoderConfig['teacher_forcing'])
         self.bidirectional = False
         self.symbols_map = symbols_map
         self.symbols_padding_id = self.symbols_map["[PAD]"]
@@ -25,8 +27,6 @@ class RNNDecoderLayer(Module):
         self.symbols_start_id = self.symbols_map["[START]"]
         self.symbols_sos_id = self.symbols_map["[SOS]"]
-        self.teacher_forcing = float(Configuration.modelDecoderConfig['teacher_forcing'])
         # Different layers
         # Symbols Embedding
         self.symbols_embedder = SymbolEmbedding(self.hidden_dim, self.symbols_vocab_size,
@@ -35,7 +35,14 @@ class RNNDecoderLayer(Module):
         self.dropout = Dropout(dropout)
         # rnn Layer
         self.rnn = LSTM(input_size=self.hidden_dim, hidden_size=self.hidden_dim, num_layers=self.num_rnn_layers,
+                        dropout=dropout,
                         bidirectional=self.bidirectional, batch_first=True)
+        self.intermediate = ModuleList()
+        for _ in range(3):
+            self.intermediate.append(Linear(self.hidden_dim, self.hidden_dim))
+        self.activation = F.gelu
         # Projection on vocab_size
         if self.bidirectional:
             self.proj = Linear(self.hidden_dim * 2, self.symbols_vocab_size)
@@ -57,22 +64,22 @@ class RNNDecoderLayer(Module):
         # y_hat[batch_size, max_len_sentence, vocab_size] init with probability pad =1
         y_hat = torch.zeros(batch_size, self.max_symbols_in_sentence, self.symbols_vocab_size,
-                            dtype=torch.float,device="cuda" if torch.cuda.is_available() else "cpu")
+                            dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
         y_hat[:, :, self.symbols_padding_id] = 1
-        use_teacher_forcing = True if random.random() < 0.05 else False
-        if use_teacher_forcing:
-            print("\n FORCING TEACHING \n")
-            decoded_i = symbols_tokenized_batch[:, 0].unsqueeze(1)
-        else :
-            decoded_i = torch.ones(batch_size, 1, dtype=torch.long,device="cuda" if torch.cuda.is_available() else "cpu")* self.symbols_start_id
-        sos_mask = torch.zeros(batch_size, dtype=torch.bool,device="cuda" if torch.cuda.is_available() else "cpu")
+        decoded_i = torch.ones(batch_size, 1, dtype=torch.long,
+                               device="cuda" if torch.cuda.is_available() else "cpu") * self.symbols_start_id
+        sos_mask = torch.zeros(batch_size, dtype=torch.bool, device="cuda" if torch.cuda.is_available() else "cpu")
         # hidden_state goes through multiple linear layers
         hidden_state = pooler_output.unsqueeze(0).repeat(self.num_rnn_layers * (1 + self.bidirectional), 1, 1)
         c_state = torch.zeros(self.num_rnn_layers * (1 + self.bidirectional), batch_size, hidden_size,
-                              dtype=torch.float,device="cuda" if torch.cuda.is_available() else "cpu")
+                              dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
+        for intermediate in self.intermediate:
+            hidden_state = self.dropout(self.activation(intermediate(hidden_state)))
+        use_teacher_forcing = True if random.random() < self.teacher_forcing else False
         # for each symbol
         for i in range(self.max_symbols_in_sentence):
@@ -97,8 +104,6 @@ class RNNDecoderLayer(Module):
             y_hat[~sos_mask, i, :-2] = proj[~sos_mask, -1, :]
            sos_mask = sos_mask_i | sos_mask
-            use_teacher_forcing = True if random.random() < 0.25 else False
            # Stop if every sentence says padding or if we are full
            if not torch.any(~sos_mask):
                break
@@ -116,20 +121,23 @@ class RNNDecoderLayer(Module):
         # contains the predictions
         y_hat = torch.zeros(batch_size, self.max_symbols_in_sentence, self.symbols_vocab_size,
-                            dtype=torch.float,device="cuda" if torch.cuda.is_available() else "cpu")
+                            dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
         y_hat[:, :, self.symbols_padding_id] = 1
         # input of the embedder, a created vector that replace the true value
-        decoded_i = torch.ones(batch_size, 1, dtype=torch.long,device="cuda" if torch.cuda.is_available() else "cpu") * self.symbols_start_id
+        decoded_i = torch.ones(batch_size, 1, dtype=torch.long,
+                               device="cuda" if torch.cuda.is_available() else "cpu") * self.symbols_start_id
-        sos_mask = torch.zeros(batch_size, dtype=torch.bool,device="cuda" if torch.cuda.is_available() else "cpu")
+        sos_mask = torch.zeros(batch_size, dtype=torch.bool, device="cuda" if torch.cuda.is_available() else "cpu")
         hidden_state = pooler_output.unsqueeze(0).repeat(self.num_rnn_layers * (1 + self.bidirectional), 1, 1)
         c_state = torch.zeros(self.num_rnn_layers * (1 + self.bidirectional), batch_size, hidden_size,
-                              dtype=torch.float,device="cuda" if torch.cuda.is_available() else "cpu")
+                              dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
+        for intermediate in self.intermediate:
+            hidden_state = self.dropout(self.activation(intermediate(hidden_state)))
         for i in range(self.max_symbols_in_sentence):
             symbols_embedding = self.symbols_embedder(decoded_i)
-            #symbols_embedding = self.dropout(symbols_embedding)
+            symbols_embedding = self.dropout(symbols_embedding)
             output, (hidden_state, c_state) = self.rnn(symbols_embedding, (hidden_state, c_state))
...
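Net effect in RNNDecoderLayer.forward: teacher forcing is now decided once per forward pass, at the rate read from the config (teacher_forcing=0.05), instead of being re-drawn inside the per-symbol loop at a hard-coded rate. A simplified, self-contained sketch of that policy (not the repository's exact decoder; decoder_step is a hypothetical stand-in for the LSTM step):

    import random
    import torch

    def greedy_decode(decoder_step, start_ids, gold=None, teacher_forcing=0.05, max_len=5):
        """decoder_step(prev_ids) -> [batch, vocab] logits; hypothetical callable."""
        # One draw for the whole sequence, mirroring the change in this commit.
        use_tf = gold is not None and random.random() < teacher_forcing
        prev, outputs = start_ids, []
        for i in range(max_len):
            logits = decoder_step(prev)
            outputs.append(logits)
            pred = torch.argmax(logits, dim=-1, keepdim=True)
            # Feed back either the gold symbol (teacher forcing) or the model's own prediction.
            prev = gold[:, i].unsqueeze(1) if use_tf else pred
        return torch.stack(outputs, dim=1)  # [batch, max_len, vocab]

    # Toy usage: a random "decoder" over a 4-symbol vocabulary.
    vocab = 4
    step = lambda prev: torch.randn(prev.size(0), vocab)
    out = greedy_decode(step, start_ids=torch.zeros(2, 1, dtype=torch.long),
                        gold=torch.randint(0, vocab, (2, 5)), teacher_forcing=0.05)
    print(out.shape)  # torch.Size([2, 5, 4])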
@@ -41,7 +41,7 @@ class EncoderDecoder(Module):
         symbols_tokenized_batch: [batch_size, max_symbols_in_sentence] the true symbols for each sentence.
         """
         last_hidden_state, pooler_output = self.encoder([sents_tokenized_batch, sents_mask_batch])
-        #last_hidden_state = self.dropout(last_hidden_state)
+        last_hidden_state = self.dropout(last_hidden_state)
         return self.decoder(symbols_tokenized_batch, last_hidden_state, pooler_output)
     def decode_greedy_rnn(self, sents_tokenized_batch, sents_mask_batch):
@@ -52,7 +52,7 @@ class EncoderDecoder(Module):
         sents_mask_batch : mask output from the encoder tokenizer
         """
         last_hidden_state, pooler_output = self.encoder([sents_tokenized_batch, sents_mask_batch])
-        #last_hidden_state = self.dropout(last_hidden_state)
+        last_hidden_state = self.dropout(last_hidden_state)
         predictions = self.decoder.predict_rnn(last_hidden_state, pooler_output)
@@ -78,10 +78,11 @@ class EncoderDecoder(Module):
             print("\nsub true (", l, ") : ",
                   [token for token in true_trad if token != '[PAD]'])
             print("\nsub predict (", len([i for i in predict_trad if i != '[PAD]']), ") : ",
-                  [token for token in predict_trad[:l] if token != '[PAD]'])
+                  [token for token in predict_trad if token != '[PAD]'])
         return measure_supertagging_accuracy(pred, b_symbols_tokenized,
-                                             ignore_idx=self.symbols_map["[PAD]"]), float(cross_entropy_loss(type_predictions, b_symbols_tokenized))
+                                             ignore_idx=self.symbols_map["[PAD]"]), float(
+            cross_entropy_loss(type_predictions, b_symbols_tokenized))
     def eval_epoch(self, dataloader, cross_entropy_loss):
         r"""Average the evaluation of all the batch.
@@ -93,7 +94,7 @@ class EncoderDecoder(Module):
         for step, batch in enumerate(dataloader):
             batch = batch
-            batch_output , loss = self.eval_batch(batch, cross_entropy_loss)
+            batch_output, loss = self.eval_batch(batch, cross_entropy_loss)
             ((bs_correct, bs_total), (bw_correct, bw_total)) = batch_output
             s_total += bs_total
             s_correct += bs_correct
...
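A note on the dropout now enabled on last_hidden_state in decode_greedy_rnn as well: it only perturbs inference if the model is left in training mode, since torch.nn.Dropout is an identity function after model.eval(). A small self-contained check:

    import torch
    from torch.nn import Dropout

    drop = Dropout(p=0.1)
    x = torch.ones(2, 4)

    drop.train()
    print(drop(x))  # some entries zeroed, the survivors scaled by 1 / (1 - p)

    drop.eval()
    print(drop(x))  # identical to x: dropout is a no-op in eval mode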
@@ -5,7 +5,6 @@ import torch
 from tqdm import tqdm
 def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500):
     print("\n" + "#" * 20)
     print("Loading csv...")
@@ -38,8 +37,8 @@ def format_time(elapsed):
     # Format as hh:mm:ss
     return str(datetime.timedelta(seconds=elapsed_rounded))
 def checkpoint_save(model, opt, epoch, dir, loss):
     torch.save({
         'epoch': epoch,
         'model_state_dict': model.state_dict(),
@@ -51,7 +50,7 @@ def checkpoint_save(model, opt, epoch, dir, loss):
 def checkpoint_load(model, opt, path):
     epoch = 0
     loss = 0
-    print("#" *15)
+    print("#" * 15)
     try:
         checkpoint = torch.load(path)
@@ -62,8 +61,6 @@ def checkpoint_load(model, opt, path):
         print("\n The loading checkpoint was successful ! \n")
         print("#" * 10)
     except Exception as e:
-        print("\nCan't load checkpoint model because : "+ str(e) +"\n\nUse default model \n")
+        print("\nCan't load checkpoint model because : " + str(e) + "\n\nUse default model \n")
         print("#" * 15)
     return model, opt, epoch, loss
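For reference, the two helpers round-trip like this. A minimal sketch only: the filename checkpoint_save writes inside dir is not visible in this diff, so the load path below is a placeholder, and the Linear model stands in for the real EncoderDecoder; checkpoint_load falls back to the objects you pass in if the file cannot be read.

    import os
    import torch
    from torch.optim import SGD
    from SuperTagger.utils import checkpoint_save, checkpoint_load

    model = torch.nn.Linear(8, 4)            # stand-in for the real model
    opt = SGD(model.parameters(), lr=0.005)

    ckpt_dir = "Output"                      # any writable directory
    checkpoint_save(model, opt, epoch=3, dir=ckpt_dir, loss=0.42)

    # Returns the (possibly restored) model, optimizer, epoch and loss.
    model, opt, epoch, loss = checkpoint_load(model, opt, os.path.join(ckpt_dir, "checkpoint.pt"))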
@@ -4,10 +4,10 @@
 #SBATCH --gres=gpu:1
 #SBATCH --mem=32000
 #SBATCH --gres-flags=enforce-binding
-#SBATCH --error="/users/celdev/jrabault/PNRIA - DeepGrail/OUT/error_rtx1.err"
-#SBATCH --output="/users/celdev/jrabault/PNRIA - DeepGrail/OUT/out_rtx1.out"
+#SBATCH --error="error_rtx1.err"
+#SBATCH --output="out_rtx1.out"
 module purge
 module load singularity/3.0.3
-srun singularity exec /logiciels/containerCollections/CUDA11/pytorch-NGC-21-03-py3.sif python "/users/celdev/jrabault/PNRIA - DeepGrail/train.py"
+srun singularity exec /logiciels/containerCollections/CUDA11/pytorch-NGC-21-03-py3.sif python "train.py"
\ No newline at end of file
# This is a sample Python script.
# Press Maj+F10 to execute it or replace it with your code.
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
def print_hi(name):
# Use a breakpoint in the code line below to debug your script.
print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint.
# Press the green button in the gutter to run the script.
if __name__ == '__main__':
print_hi('PyCharm')
# See PyCharm help at https://www.jetbrains.com/help/pycharm/
https://gitlab-ci-token:glpat-AZdpzmAPDFCSK8nPZxCw@gitlab.irit.fr/pnria/global-helper/deepgrail-rnn.git
\ No newline at end of file
-import math
+import os
 import os
 import time
+from datetime import datetime
 import numpy as np
 import torch
-from torch.optim import AdamW
-import transformers
 import torch.nn.functional as F
+import transformers
+from torch.optim import SGD
 from torch.utils.data import Dataset, TensorDataset, random_split
 from transformers import (AutoTokenizer, get_cosine_schedule_with_warmup)
 from transformers import (CamembertModel)
-from torch.utils.tensorboard import SummaryWriter
 from Configuration import Configuration
 from SuperTagger.Encoder.EncoderInput import EncoderInput
@@ -20,7 +20,7 @@ from SuperTagger.Symbol.symbol_map import symbol_map
 from SuperTagger.eval import NormCrossEntropy
 from SuperTagger.utils import format_time, read_csv_pgbar, checkpoint_save, checkpoint_load
-from datetime import datetime
+from torch.utils.tensorboard import SummaryWriter
 transformers.TOKENIZERS_PARALLELISM = True
 torch.cuda.empty_cache()
@@ -38,8 +38,7 @@ num_gru_layers = int(Configuration.modelDecoderConfig['num_rnn_layers'])
 file_path = 'Datasets/m2_dataset.csv'
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 10
+nb_sentences = batch_size * 50
-# Number of training epochs. The BERT authors recommend between 2 and 4.
 epochs = int(Configuration.modelTrainingConfig['epoch'])
 seed_val = int(Configuration.modelTrainingConfig['seed_val'])
 learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
@@ -57,7 +56,7 @@ logs_dir = os.path.join(training_dir, 'logs')
 checkpoint_dir = training_dir
 writer = SummaryWriter(log_dir=logs_dir)
-use_checkpoint_SAVE = bool(Configuration.modelTrainingConfig['use_checkpoint_SAVE'])
+use_checkpoint_SAVE = bool(Configuration.modelTrainingConfig.getboolean('use_checkpoint_SAVE'))
 # endregion OutputTraining
@@ -66,8 +65,7 @@ use_checkpoint_SAVE = bool(Configuration.modelTrainingConfig['use_checkpoint_SAV
 input_path = str(Configuration.modelTrainingConfig['input_path'])
 model_to_load = str(Configuration.modelTrainingConfig['model_to_load'])
 model_to_load_path = os.path.join(input_path, model_to_load)
-use_checkpoint_LOAD = bool(Configuration.modelTrainingConfig['use_checkpoint_LOAD'])
+use_checkpoint_LOAD = bool(Configuration.modelTrainingConfig.getboolean('use_checkpoint_LOAD'))
-print(use_checkpoint_LOAD)
 # endregion InputTraining
@@ -136,7 +134,7 @@ sents_tokenized, sents_mask = sents_tokenizer.fit_transform_tensors(df['Sentence
 dataset = TensorDataset(sents_tokenized, sents_mask, symbols_tokenized)
 # Calculate the number of samples to include in each set.
-train_size = int(0.95 * len(dataset))
+train_size = int(0.9 * len(dataset))
 val_size = len(dataset) - train_size
 # Divide the dataset by randomly selecting samples.
@@ -153,23 +151,20 @@ validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batc
 # region Fit tunning
 # Optimizer
-# optimizer_encoder = AdamW(model.encoder.parameters(),
-#                           lr=5e-5,
-#                           eps=1e-8)
-# optimizer_decoder = AdamW(model.decoder.parameters(),
-#                           lr=learning_rate,
-#                           eps=1e-8)
-optimizer = AdamW(model.parameters(),
-                  lr=learning_rate,
-                  eps=1e-8)
+optimizer_encoder = SGD(model.encoder.parameters(),
+                        lr=5e-5)
+optimizer_decoder = SGD(model.decoder.parameters(),
+                        lr=learning_rate)
 # Total number of training steps is [number of batches] x [number of epochs].
 # (Note that this is not the same as the number of training samples).
 total_steps = len(training_dataloader) * epochs
 # Create the learning rate scheduler.
-scheduler = get_cosine_schedule_with_warmup(optimizer,
-                                            num_warmup_steps=0,
-                                            num_training_steps=total_steps)
+scheduler_encoder = get_cosine_schedule_with_warmup(optimizer_encoder,
+                                                    num_warmup_steps=0,
+                                                    num_training_steps=5)
+scheduler_decoder = get_cosine_schedule_with_warmup(optimizer_decoder,
+                                                    num_warmup_steps=0,
+                                                    num_training_steps=total_steps)
@@ -199,7 +194,7 @@ total_t0 = time.time()
 validate = True
 if use_checkpoint_LOAD:
-    model, optimizer, last_epoch, loss = checkpoint_load(model, optimizer, model_to_load_path)
+    model, optimizer_decoder, last_epoch, loss = checkpoint_load(model, optimizer_decoder, model_to_load_path)
     epochs = epochs - last_epoch
@@ -242,46 +237,40 @@ def run_epochs(epochs):
             b_sents_mask = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
             b_symbols_tokenized = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-            # optimizer_encoder.zero_grad()
-            # optimizer_decoder.zero_grad()
-            optimizer.zero_grad()
+            optimizer_encoder.zero_grad()
+            optimizer_decoder.zero_grad()
             logits_predictions = model(b_sents_tokenized, b_sents_mask, b_symbols_tokenized)
             predict_trad = [{v: k for k, v in symbol_map.items()}[int(i)] for i in
                             torch.argmax(F.softmax(logits_predictions, dim=2), dim=2)[0]]
             true_trad = [{v: k for k, v in symbol_map.items()}[int(i)] for i in b_symbols_tokenized[0]]
             l = len([i for i in true_trad if i != '[PAD]'])
             if step % 40 == 0 and not step == 0:
-                writer.add_text("Sample", "\ntrain true (" + str(l) + ") : " + str([token for token in true_trad if token != '[PAD]']) + "\ntrain predict (" + str(len([i for i in predict_trad if i != '[PAD]'])) + ") : " + str([token for token in predict_trad[:l] if token != '[PAD]']))
+                writer.add_text("Sample", "\ntrain true (" + str(l) + ") : " + str(
+                    [token for token in true_trad if token != '[PAD]']) + "\ntrain predict (" + str(
+                    len([i for i in predict_trad if i != '[PAD]'])) + ") : " + str(
+                    [token for token in predict_trad[:l] if token != '[PAD]']))
             loss = cross_entropy_loss(logits_predictions, b_symbols_tokenized)
             # Perform a backward pass to calculate the gradients.
             total_train_loss += float(loss)
             loss.backward()
-            # Clip the norm of the gradients to 1.0.
             # This is to help prevent the "exploding gradients" problem.
-            # torch.nn.utils.clip_grad_norm_(model.encoder.parameters(), 1.0)
-            # torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), 1.0)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)
             # Update parameters and take a step using the computed gradient.
-            # The optimizer dictates the "update rule"--how the parameters are
-            # modified based on their gradients, the learning rate, etc.
-            # optimizer_encoder.step()
-            # optimizer_decoder.step()
-            optimizer.step()
-            scheduler.step()
+            optimizer_encoder.step()
+            optimizer_decoder.step()
+            scheduler_encoder.step()
+            scheduler_decoder.step()
         # checkpoint
         if use_checkpoint_SAVE:
-            checkpoint_save(model, optimizer, epoch_i, checkpoint_dir, loss)
+            checkpoint_save(model, optimizer_decoder, epoch_i, checkpoint_dir, loss)
         avg_train_loss = total_train_loss / len(training_dataloader)
@@ -313,5 +302,3 @@ def run_epochs(epochs):
 run_epochs(epochs)
 # endregion Train
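Taken together, the train.py changes replace the single AdamW optimizer with separate SGD optimizers and cosine schedules for the encoder and the decoder, and re-enable gradient clipping. A condensed, self-contained sketch of that per-batch update on a toy encoder/decoder (the real model, dataloader and NormCrossEntropy loss are omitted; hyperparameter values mirror the diff):

    import torch
    import torch.nn.functional as F
    from torch.nn import Linear
    from torch.optim import SGD
    from transformers import get_cosine_schedule_with_warmup

    # Toy stand-ins for model.encoder / model.decoder.
    encoder, decoder = Linear(8, 8), Linear(8, 4)

    optimizer_encoder = SGD(encoder.parameters(), lr=5e-5)
    optimizer_decoder = SGD(decoder.parameters(), lr=0.005)
    scheduler_encoder = get_cosine_schedule_with_warmup(optimizer_encoder,
                                                        num_warmup_steps=0, num_training_steps=5)
    scheduler_decoder = get_cosine_schedule_with_warmup(optimizer_decoder,
                                                        num_warmup_steps=0, num_training_steps=100)

    x, target = torch.randn(2, 8), torch.randint(0, 4, (2,))

    optimizer_encoder.zero_grad()
    optimizer_decoder.zero_grad()
    loss = F.cross_entropy(decoder(encoder(x)), target)
    loss.backward()
    # Clip over all parameters, as the new training loop does, to tame exploding gradients.
    torch.nn.utils.clip_grad_norm_(list(encoder.parameters()) + list(decoder.parameters()),
                                   max_norm=5.0, norm_type=2)
    optimizer_encoder.step()
    optimizer_decoder.step()
    scheduler_encoder.step()
    scheduler_decoder.step()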