main.py
from transformers import BertTokenizer, BertModel
import torch
# Load tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
# Encode input text
inputs = tokenizer("Example sentence for BERT attention visualization.", return_tensors="pt")
# Forward pass (no gradients needed, we only inspect attentions)
with torch.no_grad():
    outputs = model(**inputs)
attentions = outputs.attentions # Tuple of attention tensors for each layer
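# `attentions` contains one tensor per layer, each of shape (batch_size, num_heads, seq_len, seq_len)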
# Look up the vocabulary id for the word of interest, e.g., "attention"
token_id = tokenizer.convert_tokens_to_ids("attention")
# Find the positions of this token in the input sequence
token_positions = (inputs['input_ids'][0] == token_id).nonzero(as_tuple=True)[0]
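# `token_positions` holds the sequence indices where that vocabulary id occurs
# (it can be empty if the word was split into sub-word tokens)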
# Access the attention from one of these positions, e.g., first layer, first head
attention_layer_head = attentions[0][0, 0, token_positions[0], :]  # layer 0, batch item 0, head 0, query position
# Now `attention_layer_head` contains the attention weights from the word "attention" to all other tokens in this specific layer and head
print(attention_layer_head)
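# --- Optional sketch, not part of the original snippet: pair each attention weight with
# its token string so the printed tensor is easier to read. Assumes the variables defined
# above (`tokenizer`, `inputs`, `attention_layer_head`) are still in scope.
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
for token, weight in zip(tokens, attention_layer_head.tolist()):
    print(f"{token:>15}  {weight:.4f}")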