diff --git a/Configuration/config.ini b/Configuration/config.ini index 64bba529b04897e70ebaee8328c7ad7275826921..69d1a5c600d73b737e8cfd1215cde5379f1cda52 100644 --- a/Configuration/config.ini +++ b/Configuration/config.ini @@ -26,7 +26,7 @@ dropout=0.1 sinkhorn_iters=3 [MODEL_TRAINING] -batch_size=16 +batch_size=32 epoch=30 seed_val=42 learning_rate=2e-4 diff --git a/Linker/Linker.py b/Linker/Linker.py index 8429260558d0a192c6240a4e85f25da841df3f94..f76404fc7566c31f94eb70b2930e5d60394582e1 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -257,7 +257,8 @@ class Linker(Module): logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask) # Run the kinker on the categories predictions - logits_predictions = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding, batch_sentences_mask) + logits_predictions = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding, + batch_sentences_mask) linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links) # Perform a backward pass to calculate the gradients. @@ -366,9 +367,10 @@ class Linker(Module): :param positional_ids: A List of batch_size elements, each being a List of num_atoms LongTensors. Each LongTensor in positional_ids[b][a] indexes the location of atoms of type a in sentence b. - :param device: + :param atom_type: :return: """ - return [[bsd_tensor.select(0, index=i).index_select(0, index=atom.to(self.device)) for atom in sentence] - for i, sentence in enumerate(positional_ids[atom_map_redux[atom_type]])] \ No newline at end of file + return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0,index=int(atom)) + if atom != -1 else torch.zeros(self.dim_embedding_atoms) for atom in sentence]) + for i, sentence in enumerate(positional_ids[:, atom_map_redux[atom_type], :])])