Skip to content
Snippets Groups Projects
Commit eff2f26e authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

update utils

parent 140ce64e
No related branches found
No related tags found
3 merge requests!6Linker with transformer,!5Linker with transformer,!2Change preprocess
...@@ -26,7 +26,7 @@ dropout=0.1 ...@@ -26,7 +26,7 @@ dropout=0.1
sinkhorn_iters=3 sinkhorn_iters=3
[MODEL_TRAINING] [MODEL_TRAINING]
batch_size=16 batch_size=32
epoch=30 epoch=30
seed_val=42 seed_val=42
learning_rate=2e-4 learning_rate=2e-4
...@@ -257,7 +257,8 @@ class Linker(Module): ...@@ -257,7 +257,8 @@ class Linker(Module):
logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask) logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
# Run the kinker on the categories predictions # Run the kinker on the categories predictions
logits_predictions = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding, batch_sentences_mask) logits_predictions = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding,
batch_sentences_mask)
linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links) linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
# Perform a backward pass to calculate the gradients. # Perform a backward pass to calculate the gradients.
...@@ -366,9 +367,10 @@ class Linker(Module): ...@@ -366,9 +367,10 @@ class Linker(Module):
:param positional_ids: :param positional_ids:
A List of batch_size elements, each being a List of num_atoms LongTensors. A List of batch_size elements, each being a List of num_atoms LongTensors.
Each LongTensor in positional_ids[b][a] indexes the location of atoms of type a in sentence b. Each LongTensor in positional_ids[b][a] indexes the location of atoms of type a in sentence b.
:param device: :param atom_type:
:return: :return:
""" """
return [[bsd_tensor.select(0, index=i).index_select(0, index=atom.to(self.device)) for atom in sentence] return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0,index=int(atom))
for i, sentence in enumerate(positional_ids[atom_map_redux[atom_type]])] if atom != -1 else torch.zeros(self.dim_embedding_atoms) for atom in sentence])
\ No newline at end of file for i, sentence in enumerate(positional_ids[:, atom_map_redux[atom_type], :])])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment