Commit f920249c authored by Julien Rabault

Fix device

parent 35345fe6
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
@@ -11,7 +11,6 @@ from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from transformers import get_cosine_schedule_with_warmup

 from Configuration import Configuration
 from Linker.AtomEmbedding import AtomEmbedding
@@ -92,6 +91,9 @@ class Linker(Module):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.to(self.device)

     def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
         r"""
         Args:
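Note on the hunk above: calling self.to(self.device) once in the constructor moves every registered submodule's parameters and buffers to the chosen device, which is why the later hunks can drop the per-tensor .to(self.device) calls. A minimal sketch of the pattern with illustrative names (TinyLinker is not part of this repository):

    import torch
    from torch.nn import Module, Linear

    class TinyLinker(Module):
        def __init__(self, dim=8):
            super().__init__()
            # illustrative submodules, standing in for the real embeddings/transformations
            self.pos_transformation = Linear(dim, dim)
            self.neg_transformation = Linear(dim, dim)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.to(self.device)  # one call places all registered parameters on the device

        def forward(self, x):
            # inputs are not registered on the module, so they still need an explicit move
            x = x.to(self.device)
            return self.pos_transformation(x)

    model = TinyLinker()
    out = model(torch.randn(2, 8))
    print(out.device)  # same device as the module's parameters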
@@ -161,13 +163,13 @@ class Linker(Module):
                 [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
                                             atoms_polarity_batch, atom_type, s_idx)
                  for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+                max_len=self.max_atoms_in_one_type // 2)
             neg_encoding = pad_sequence(
                 [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
                                             atoms_polarity_batch, atom_type, s_idx)
                  for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+                max_len=self.max_atoms_in_one_type // 2)

             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
@@ -195,8 +197,6 @@ class Linker(Module):
         """
         training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                             validation_rate)
-        self.to(self.device)

         if checkpoint or tensorboard:
             checkpoint_dir, writer = output_create_dir()
@@ -210,8 +210,7 @@ class Linker(Module):
             print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
             print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
             if validation_rate > 0.0:
-                with torch.no_grad():
-                    loss_test, accuracy_test = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
+                loss_test, accuracy_test = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
                 print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')

             if checkpoint:
@@ -241,14 +240,13 @@ class Linker(Module):
             accuracy on validation set
             loss on train set
         """
-        self.train()
         # Reset the total loss for this epoch.
         epoch_loss = 0
         accuracy_train = 0
         t0 = time.time()
+        self.train()

         # For each batch of training data...
         with tqdm(training_dataloader, unit="batch") as tepoch:
             for batch in tepoch:
@@ -299,44 +297,44 @@ class Linker(Module):
             axiom_links : atom_vocab_size, batch-size, max_atoms_in_one_cat)
         """
         self.eval()
+        with torch.no_grad():
             # get atoms
             atoms_batch = get_atoms_batch(categories)
             atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)

             # get polarities
             polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)

             # atoms embedding
             atoms_embedding = self.atoms_embedding(atoms_tokenized)

             # MHA or LSTM over the BERT output
             atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
                                                  self.make_decoder_mask(atoms_tokenized))

             link_weights = []
             for atom_type in list(self.atom_map.keys())[:-1]:
                 pos_encoding = pad_sequence(
                     [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
                                                 polarities, atom_type, s_idx)
                      for s_idx in range(len(polarities))], padding_value=0,
-                    max_len=self.max_atoms_in_one_type // 2).to(self.device)
+                    max_len=self.max_atoms_in_one_type // 2)
                 neg_encoding = pad_sequence(
                     [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
                                                 polarities, atom_type, s_idx)
                      for s_idx in range(len(polarities))], padding_value=0,
-                    max_len=self.max_atoms_in_one_type // 2).to(self.device)
+                    max_len=self.max_atoms_in_one_type // 2)

                 pos_encoding = self.pos_transformation(pos_encoding)
                 neg_encoding = self.neg_transformation(neg_encoding)

                 weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
                 link_weights.append(sinkhorn(weights, iters=3))

             logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
             axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)

             return axiom_links

     def eval_batch(self, batch, cross_entropy_loss):
         batch_atoms = batch[0].to(self.device)
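The hunk above wraps the whole prediction forward pass in torch.no_grad(), so no autograd graph is recorded during inference. A hedged sketch of that pattern with a stand-in model (predict_links is hypothetical, not the repository's predict signature):

    import torch
    import torch.nn.functional as F
    from torch.nn import Linear

    def predict_links(model, inputs):
        # eval mode plus no_grad around the whole forward pass,
        # followed by an argmax over the link scores, as in the diff
        model.eval()
        with torch.no_grad():
            logits = model(inputs)  # (batch, n, n) illustrative score tensor
            return torch.argmax(F.log_softmax(logits, dim=-1), dim=-1)

    model = Linear(4, 4)  # stand-in for the real linker
    links = predict_links(model, torch.randn(2, 3, 4))
    print(links.shape)    # torch.Size([2, 3])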
@@ -361,12 +359,14 @@ class Linker(Module):
         Args:
             dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols
         """
+        self.eval()
         accuracy_average = 0
         loss_average = 0
-        for step, batch in enumerate(dataloader):
-            loss, accuracy = self.eval_batch(batch, cross_entropy_loss)
-            accuracy_average += accuracy
-            loss_average += float(loss)
+        with torch.no_grad():
+            for step, batch in enumerate(dataloader):
+                loss, accuracy = self.eval_batch(batch, cross_entropy_loss)
+                accuracy_average += accuracy
+                loss_average += float(loss)

         return loss_average / len(dataloader), accuracy_average / len(dataloader)
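The evaluation loop now switches to eval mode and disables gradients once around the whole loop instead of at the call site in training. A small sketch of the same shape under assumed names (here eval_epoch is a free function and eval_batch_fn a dummy metric, unlike the real method on Linker):

    import torch

    def eval_epoch(model, dataloader, eval_batch_fn):
        # eval mode + a single no_grad scope, then average per-batch metrics
        model.eval()
        loss_average, accuracy_average = 0.0, 0.0
        with torch.no_grad():
            for batch in dataloader:
                loss, accuracy = eval_batch_fn(batch)
                loss_average += float(loss)
                accuracy_average += accuracy
        return loss_average / len(dataloader), accuracy_average / len(dataloader)

    fake_loader = [torch.randn(2, 4) for _ in range(3)]      # stands in for a DataLoader
    loss, acc = eval_epoch(torch.nn.Linear(4, 1), fake_loader,
                           lambda b: (b.sum(), 1.0))          # dummy per-batch metric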