From eff2f26e921472c1b161b8307d4c6596d390a019 Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <cdepourt@montana.irit.fr> Date: Tue, 24 May 2022 10:19:41 +0200 Subject: [PATCH] update utils --- Configuration/config.ini | 2 +- Linker/Linker.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/Configuration/config.ini b/Configuration/config.ini index 64bba52..69d1a5c 100644 --- a/Configuration/config.ini +++ b/Configuration/config.ini @@ -26,7 +26,7 @@ dropout=0.1 sinkhorn_iters=3 [MODEL_TRAINING] -batch_size=16 +batch_size=32 epoch=30 seed_val=42 learning_rate=2e-4 diff --git a/Linker/Linker.py b/Linker/Linker.py index 8429260..f76404f 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -257,7 +257,8 @@ class Linker(Module): logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask) # Run the kinker on the categories predictions - logits_predictions = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding, batch_sentences_mask) + logits_predictions = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding, + batch_sentences_mask) linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links) # Perform a backward pass to calculate the gradients. @@ -366,9 +367,10 @@ class Linker(Module): :param positional_ids: A List of batch_size elements, each being a List of num_atoms LongTensors. Each LongTensor in positional_ids[b][a] indexes the location of atoms of type a in sentence b. - :param device: + :param atom_type: :return: """ - return [[bsd_tensor.select(0, index=i).index_select(0, index=atom.to(self.device)) for atom in sentence] - for i, sentence in enumerate(positional_ids[atom_map_redux[atom_type]])] \ No newline at end of file + return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0,index=int(atom)) + if atom != -1 else torch.zeros(self.dim_embedding_atoms) for atom in sentence]) + for i, sentence in enumerate(positional_ids[:, atom_map_redux[atom_type], :])]) -- GitLab