Commit eff2f26e authored by Caroline DE POURTALES

update utils

parent 140ce64e
Related merge requests: !6 Linker with transformer, !5 Linker with transformer, !2 Change preprocess
@@ -26,7 +26,7 @@ dropout=0.1
 sinkhorn_iters=3
 [MODEL_TRAINING]
-batch_size=16
+batch_size=32
 epoch=30
 seed_val=42
 learning_rate=2e-4
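
For orientation, a minimal sketch of how these training settings might be loaded with Python's standard configparser; the file name config.ini is an assumption, while the section and key names come from the hunk above.

# Hedged sketch: loading the [MODEL_TRAINING] hyperparameters with configparser.
# "config.ini" is an assumed file name; keys mirror the hunk above.
from configparser import ConfigParser

config = ConfigParser()
config.read("config.ini")

batch_size = config.getint("MODEL_TRAINING", "batch_size")          # 32 after this commit
epochs = config.getint("MODEL_TRAINING", "epoch")                   # 30
seed_val = config.getint("MODEL_TRAINING", "seed_val")              # 42
learning_rate = config.getfloat("MODEL_TRAINING", "learning_rate")  # 2e-4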
@@ -257,7 +257,8 @@ class Linker(Module):
 logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 # Run the linker on the category predictions
-logits_predictions = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding, batch_sentences_mask)
+logits_predictions = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding,
+                          batch_sentences_mask)
 linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
 # Perform a backward pass to calculate the gradients.
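
The hunk above only re-wraps the linker call; for context, here is a self-contained toy sketch of the training-step pattern it sits in (forward pass, cross-entropy loss, backward pass). The module, optimizer, and tensor names are illustrative stand-ins, not the project's API.

# Toy stand-ins: a linear "linker" scoring 4 possible links from 8-dim features.
import torch
from torch.nn import CrossEntropyLoss

linker = torch.nn.Linear(8, 4)
optimizer = torch.optim.AdamW(linker.parameters(), lr=2e-4)  # learning_rate from the config hunk
loss_fn = CrossEntropyLoss()

features = torch.randn(32, 8)             # stand-in for the sentence/atom embeddings
true_links = torch.randint(0, 4, (32,))   # stand-in for batch_true_links

logits_predictions = linker(features)
linker_loss = loss_fn(logits_predictions, true_links)

optimizer.zero_grad()
linker_loss.backward()   # the backward pass referenced in the comment above
optimizer.step()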
@@ -366,9 +367,10 @@ class Linker(Module):
 :param positional_ids:
     A List of batch_size elements, each being a List of num_atoms LongTensors.
     Each LongTensor in positional_ids[b][a] indexes the location of atoms of type a in sentence b.
 :param device:
 :param atom_type:
 :return:
 """
-return [[bsd_tensor.select(0, index=i).index_select(0, index=atom.to(self.device)) for atom in sentence]
-        for i, sentence in enumerate(positional_ids[atom_map_redux[atom_type]])]
\ No newline at end of file
+return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0, index=int(atom))
+                                 if atom != -1 else torch.zeros(self.dim_embedding_atoms) for atom in sentence])
+                    for i, sentence in enumerate(positional_ids[:, atom_map_redux[atom_type], :])])
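
The rewritten return now produces a dense, zero-padded tensor instead of nested lists. Below is a small runnable sketch of the same gather-and-pad pattern on toy data; the shapes (batch, atom types, max atoms) and the use of -1 as a padding index are assumptions inferred from the new code, not taken from the repository.

# Toy sketch of the gather-and-pad pattern in the new return value.
# Assumed shapes: bsd_tensor is (batch, seq_len, dim); positional_ids is
# (batch, n_atom_types, max_atoms) with -1 marking padded positions.
import torch

batch, seq_len, dim, max_atoms = 2, 5, 4, 3
bsd_tensor = torch.randn(batch, seq_len, dim)
positional_ids = torch.tensor([[[0, 3, -1]],
                               [[1, -1, -1]]])  # a single atom type, for simplicity

atom_type_index = 0  # stands in for atom_map_redux[atom_type]
out = torch.stack([
    torch.stack([bsd_tensor[i, int(atom)] if atom != -1 else torch.zeros(dim)
                 for atom in sentence])
    for i, sentence in enumerate(positional_ids[:, atom_type_index, :])
])
print(out.shape)  # torch.Size([2, 3, 4]): padded positions come out as zero vectors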