# main.py
from transformers import BertTokenizer, BertModel
import torch

# Load tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
model.eval()  # disable dropout so the attention weights are deterministic

# Encode input text
inputs = tokenizer("Example sentence for BERT attention visualization.", return_tensors="pt")

# Forward pass (no gradients needed for inspection); collect the attentions
with torch.no_grad():
    outputs = model(**inputs)
attentions = outputs.attentions  # Tuple with one attention tensor per layer
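
# Sanity check: bert-base-uncased has 12 layers and 12 attention heads, so
# `attentions` holds 12 tensors, each shaped (batch_size, num_heads, seq_len, seq_len)
print(len(attentions), attentions[0].shape)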

# Look up the vocabulary id of a word of interest, e.g., "attention"
token_id = tokenizer.convert_tokens_to_ids("attention")

# Find the positions of this token in the input sequence
token_positions = (inputs['input_ids'][0] == token_id).nonzero(as_tuple=True)[0]
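
# Caveat (worth knowing, not part of the original snippet): convert_tokens_to_ids
# only matches words that exist as a single entry in the WordPiece vocabulary.
# "attention" does in bert-base-uncased, but many words split into subwords and
# would map to [UNK] here. Decoding the ids makes this easy to check:
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
print(tokens)  # includes the special [CLS]/[SEP] tokens and any subword pieces
assert len(token_positions) > 0, "word not found as a single token in the input"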

# Access the attention from one of these positions, e.g., first layer, first head
attention_layer_head = attentions[0][0, 0, token_positions[0], :]

# `attention_layer_head` now holds the attention weights from the word "attention"
# to every token in the sequence, for this specific layer and head
print(attention_layer_head)
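
# A more readable view (a small sketch building on the snippet above): pair each
# attention weight with the token it attends to
for token, weight in zip(tokens, attention_layer_head.tolist()):
    print(f"{token:>15}  {weight:.4f}")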