diff --git a/Linker/Linker.py b/Linker/Linker.py index f76404fc7566c31f94eb70b2930e5d60394582e1..dc3e6ee63b469db642ead6bbb8f0520e617652a7 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -371,6 +371,7 @@ class Linker(Module): :return: """ - return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0,index=int(atom)) - if atom != -1 else torch.zeros(self.dim_embedding_atoms) for atom in sentence]) - for i, sentence in enumerate(positional_ids[:, atom_map_redux[atom_type], :])]) + return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0, index=int(atom)).to(self.device) + if atom != -1 else torch.zeros(self.dim_embedding_atoms, device=self.device) + for atom in sentence]) + for i, sentence in enumerate(positional_ids[:, atom_map_redux[atom_type], :])])