diff --git a/Linker/PositionEncoding.py b/Linker/PositionEncoding.py index d0d6524c6d948927e8a99419d415dbf3c07d61bf..5389a7a17bf6cb8c2866b3ba80f6e9bd5eff63ce 100644 --- a/Linker/PositionEncoding.py +++ b/Linker/PositionEncoding.py @@ -21,5 +21,5 @@ class PositionalEncoding(nn.Module): Args: x: Tensor, shape [batch_size,seq_len, embedding_dim] """ - x = x + self.pe[:x.size(0)] + x = x + self.pe[:, :x.size(1)] return self.dropout(x)