Spaces:
Runtime error
Runtime error
File size: 10,005 Bytes
c7c30cc a6e310c 098d586 c7c30cc a6e310c d4bacf6 a6e310c 098d586 a6e310c 098d586 a6e310c 098d586 a6e310c 098d586 a6e310c 098d586 a6e310c 098d586 45c11ea 098d586 a6e310c c7c30cc a6e310c c7c30cc 1ebe266 c7c30cc a6e310c c7c30cc 098d586 1ebe266 a6e310c 098d586 a6e310c c7c30cc 098d586 a6e310c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import io
import torch
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from PIL import Image
from transformers import AutoTokenizer
from lxt.models.llama import LlamaForCausalLM, attnlrp
from lxt.utils import clean_tokens
# One checkpoint id is used for both the weights and the tokenizer
MODEL_CHECKPOINT = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load the causal LM in bfloat16; placement is delegated to device_map="auto"
model = LlamaForCausalLM.from_pretrained(
    MODEL_CHECKPOINT,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

# Register the AttnLRP rules on the loaded model (must happen after loading)
attnlrp.register(model)
def really_clean_tokens(tokens):
    """Return display-ready token strings for the visualization widgets.

    The tokens are first passed through ``clean_tokens`` (imported from
    ``lxt.utils``); afterwards underscores and the SentencePiece
    word-boundary marker (U+2581) are turned into spaces and the ``<s>`` tag
    is removed.

    Args:
        tokens: Sequence of raw token strings, e.g. the result of
            ``tokenizer.convert_ids_to_tokens``.

    Returns:
        A list with one cleaned string per input token.
    """
    # The marker is spelled as an escape so it survives any encoding
    # round-trip of this file. The previous literal was a mis-decoded Greek
    # beta, which never matched the marker and erased genuine betas instead.
    word_boundary = "\u2581"
    return [
        token.replace("_", " ").replace(word_boundary, " ").replace("<s>", "")
        for token in clean_tokens(tokens)
    ]
def generate_and_visualize(prompt, num_tokens=10):
    """Generate tokens greedily and collect one relevance vector per step.

    At every step the model is run on the current input embeddings, the
    largest logit at the last position is back-propagated (seeded with its
    own value), and the ``.grad`` of the embeddings is summed over its last
    axis to obtain one score per input position. The winning token id is
    then appended to the sequence and the embeddings are recomputed.

    Args:
        prompt: Text to continue.
        num_tokens: Number of tokens to generate.

    Returns:
        A 5-tuple ``(input_tokens, all_relevances, generated_text,
        num_generated, generated_tokens)``:

        * cleaned token strings of the prompt followed by everything that
          was generated,
        * a list with one float CPU tensor of scores per generation step
          (each step is scored against the sequence as it was at that step,
          so later vectors are presumably longer - one entry per position),
        * the decoded generated text,
        * the number of generated tokens,
        * the cleaned token strings of the generated tokens only.
    """
    embed = model.get_input_embeddings()
    token_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)
    embeddings = embed(token_ids)

    new_token_ids = []
    relevance_per_step = []
    for _ in range(num_tokens):
        logits = model(inputs_embeds=embeddings.requires_grad_(), use_cache=False).logits
        top_logit, top_index = torch.max(logits[0, -1, :], dim=-1)

        # Seed the backward pass with the logit value itself
        top_logit.backward(top_logit)
        relevance_per_step.append(embeddings.grad.float().sum(-1).cpu()[0])

        # Greedy choice: keep the arg-max token
        picked = top_index.unsqueeze(0)
        new_token_ids.append(picked.item())

        # Extend the sequence and rebuild the embeddings for the next step
        token_ids = torch.cat([token_ids, picked.unsqueeze(0)], dim=1)
        embeddings = embed(token_ids)

    # token_ids now holds the prompt plus all generated ids
    input_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(token_ids[0]))
    generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(new_token_ids))
    generated_text = tokenizer.decode(new_token_ids)
    print(f"Generated text: {generated_text}")
    return input_tokens, relevance_per_step, generated_text, len(new_token_ids), generated_tokens
def update_visualization_step(input_tokens, all_relevances, output_tokens, step):
    """Build the highlight data for one generation step.

    The relevance vector of the requested step is min-max normalized to
    ``[0, 1]``; scores at or below 0.05 are floored to 0 and every other
    score is shifted down by 0.05.

    Args:
        input_tokens: Token strings to pair with the scores.
        all_relevances: One relevance vector per generation step. Each vector
            must offer ``min()``, ``max()`` and element-wise arithmetic
            (torch tensors and NumPy arrays both work).
        output_tokens: Generated token strings, one per step.
        step: Index of the step to show. It is converted to ``int`` and
            clamped to the available steps.

    Returns:
        ``(highlighted_tokens, current_tokens)`` where ``highlighted_tokens``
        is a list of ``(token, score)`` tuples and ``current_tokens`` is
        ``output_tokens`` up to and including ``step``. Both are empty lists
        when ``all_relevances`` is empty.
    """
    # Nothing has been generated yet (e.g. the slider moved before the first
    # run): return empty data instead of raising IndexError.
    if len(all_relevances) == 0:
        return [], []

    # Slider values may arrive as floats or be stale after a shorter run
    step = min(max(int(step), 0), len(all_relevances) - 1)
    relevance = all_relevances[step]

    # Normalize relevance between [0, 1] for highlighting
    lowest = relevance.min()
    spread = relevance.max() - lowest
    if float(spread) == 0.0:
        # All scores are equal: dividing would yield NaN, so show no highlight
        scores = [0.0] * len(relevance)
    else:
        scores = (relevance - lowest) / spread

    # Create list of (token, score) tuples for HighlightedText
    highlighted_tokens = [
        (token, max(float(score), 0.05) - 0.05)
        for token, score in zip(input_tokens, scores)
    ]
    return highlighted_tokens, output_tokens[: step + 1]
def generate_heatmap(input_tokens, all_relevances, output_tokens):
    """Render the per-step relevance scores as a heatmap image.

    Rows are generation steps and columns are input positions. Only columns
    whose summed absolute score exceeds 1.0, plus their neighbourhood, are
    drawn; if no column passes that threshold, every column is drawn.

    Args:
        input_tokens: Token strings used as x-axis labels (indexed by
            column).
        all_relevances: One relevance vector per generation step; must
            contain at least one entry. Every vector is cut to the length of
            the first one.
        output_tokens: Generated token strings used as y-axis labels, one
            per row.

    Returns:
        The rendered PNG as a NumPy array.
    """
    # Create a matrix of scores; all rows are cut to the first step's length
    # so that they line up
    prompt_length = len(all_relevances[0])
    attention_matrix = np.array([el[:prompt_length] for el in all_relevances])
    num_cols = attention_matrix.shape[1]

    # Find columns with strong enough scores (including adjacent columns).
    # Each pass widens the already widened set, so offsets 0..4 accumulate to
    # a window of +/-10 columns around every strong column.
    non_zero_cols = np.where(np.abs(attention_matrix).sum(axis=0) > 1.)[0]
    for col in range(5):
        non_zero_cols = np.union1d(non_zero_cols, non_zero_cols + col)
        non_zero_cols = np.union1d(non_zero_cols, non_zero_cols - col)
    # union1d already returns sorted unique values; just drop out-of-range ones
    non_zero_cols = non_zero_cols[(non_zero_cols >= 0) & (non_zero_cols < num_cols)]

    # An empty selection would make seaborn fail on a zero-width matrix,
    # so fall back to showing every column
    if non_zero_cols.size == 0:
        non_zero_cols = np.arange(num_cols)

    # Filter the matrix and input tokens
    filtered_matrix = attention_matrix[:, non_zero_cols]
    filtered_input_tokens = [input_tokens[i] for i in non_zero_cols]

    # Create the heatmap on an explicit figure rather than pyplot's implicit
    # "current figure", and always close it so a failure does not leak it
    fig, ax = plt.subplots(figsize=(20, 7))
    try:
        sns.heatmap(
            filtered_matrix,
            xticklabels=filtered_input_tokens,
            yticklabels=output_tokens,
            cmap="YlOrRd",
            ax=ax,
        )
        ax.set_title("Attention Heatmap (shows only input tokens where there was strong enough attention)")
        ax.set_xlabel("Input Tokens")
        ax.set_ylabel("Output Tokens")
        ax.tick_params(axis="x", labelrotation=90)
        ax.tick_params(axis="y", labelrotation=0)

        # Save the plot to a bytes buffer
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)

    # Return the image as a numpy array
    buf.seek(0)
    return np.array(Image.open(buf))
def on_generate(input_text, num_tokens):
    """Run generation and assemble the values for every output component.

    Args:
        input_text: Prompt to continue.
        num_tokens: Number of tokens to generate; converted to ``int``
            because it comes from a slider.

    Returns:
        An 8-tuple: input tokens, per-step relevances, highlight data for
        step 0, output tokens, the first output token (as a one-element
        slice), the generated text, a ``gr.Slider`` whose maximum is the last
        step index with its value reset to 0, and the heatmap image.
    """
    input_tokens, relevances, generated_text, generated_length, output_tokens = generate_and_visualize(
        input_text, int(num_tokens)
    )

    # `relevances` is a plain list of tensors (it has no `fillna`), so NaNs
    # are zeroed tensor by tensor; only the heatmap gets the cleaned copy.
    heatmap_relevances = [
        torch.where(torch.isnan(rel), torch.zeros_like(rel), rel)
        for rel in relevances
    ]
    heatmap = generate_heatmap(input_tokens, heatmap_relevances, output_tokens)

    first_step_highlight = update_visualization_step(input_tokens, relevances, output_tokens, 0)[0]
    return (
        input_tokens,
        relevances,
        first_step_highlight,
        output_tokens,
        output_tokens[:1],
        generated_text,
        gr.Slider(maximum=generated_length-1, value=0),
        heatmap
    )
# Define example inputs.
# Each entry is a [prompt, number_of_tokens] pair, in the same order as the
# `inputs=[input_text, num_tokens]` list handed to `gr.Examples` further down.
# The first prompt is also used as the initial value of the prompt textbox.
examples = [
[
"""Context: Mount Everest attracts many climbers, including highly experienced mountaineers.
There are two main climbing routes, one approaching the summit from the southeast in Nepal (known as the standard route) and the other from the north in Tibet. While not posing substantial technical climbing challenges on the standard route, Everest presents dangers such as altitude sickness, weather, and wind, as well as hazards from avalanches and the Khumbu Icefall. As of November 2022, 310 people have died on Everest.
Over 200 bodies remain on the mountain and have not been removed due to the dangerous conditions. The first recorded efforts to reach Everest's summit were made by British mountaineers.
As Nepal did not allow foreigners to enter the country at the time, the British made several attempts on the north ridge route from the Tibetan side. After the first reconnaissance expedition by the British in 1921 reached 7,000 m (22,970 ft) on the North Col, the 1922 expedition pushed the north ridge route up to 8,320 m (27,300 ft), marking the first time a human had climbed above 8,000 m (26,247 ft).
The 1924 expedition resulted in one of the greatest mysteries on Everest to this day: George Mallory and Andrew Irvine made a final summit attempt on 8 June but never returned, sparking debate as to whether they were the first to reach the top.
Tenzing Norgay and Edmund Hillary made the first documented ascent of Everest in 1953, using the southeast ridge route. Norgay had reached 8,595 m (28,199 ft) the previous year as a member of the 1952 Swiss expedition. The Chinese mountaineering team of Wang Fuzhou, Gonpo, and Qu Yinhua made the first reported ascent of the peak from the north ridge on 25 May 1960. \nQuestion: How high did they climb in 1922? According to the text, the 1922 expedition reached 8,""",
10
],
[
"""Hurricane Katrina killed hundreds of people as it made landfall on New Orleans in 2005 - many of these deaths could have been avoided if alerts had been given one day earlier. Accurate weather forecasts are really life-saving.
π₯ Now, NASA and IBM just dropped a game-changing new model: the first ever foundation model for weather! This means, it's the first time we have a generalist model not restricted to one task, but able to predict 160 weather variables!
Prithvi WxC (Prithvi, βΰ€ͺΰ₯ΰ€₯ΰ₯ΰ€΅ΰ₯β, is the Sanskrit name for Earth) - is a 2.3 billion parameter model, with an architecture close to previous vision transformers like Hiera.
π‘ But it comes with some important tweaks: under the hood, Prithvi WxC uses a clever transformer-based architecture with 25 encoder and 5 decoder blocks. It alternates between "local" and "global" attention to capture both regional and global weather patterns. How many weather variables can Prithvi predict? Prithvi can""",
15
],
[
"""Transformers v4.45.0 released: includes a lightning-fast method to build tools! β‘οΈ
During user research with colleagues @MoritzLaurer and @Jofthomas , we discovered that the class definition currently in used to define a Tool in
transformers.agents is a bit tedious to use, because it goes in great detail.
β‘οΈ So Iβve made an easier way to build tools: just make a function with type hints + a docstring, and add a @tool decorator in front.
β
VoilΓ , youβre good to go!
How can you build tools simply in transformers? Just use the decorator""",
20
]
]
# Define Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""# Attribution Visualization Demo
This demo uses the library [LXT](https://lxt.readthedocs.io/en/latest/quickstart.html#tinyllama) to attribute the output tokens to some input tokens.""")

    # Visible components, created in display order
    input_text = gr.Textbox(label="Input Prompt", lines=5, value=examples[0][0])
    num_tokens = gr.Slider(minimum=1, maximum=50, step=1, value=10, label="Number of tokens to generate")
    generate_button = gr.Button("Generate and Visualize")
    generated_output = gr.Textbox(label="Generated Text")
    heatmap_output = gr.Image(label="Attention Heatmap")
    step_slider = gr.Slider(minimum=0, maximum=1, step=1, value=0, label="Visualization Step")
    attention_on_inputs = gr.HighlightedText(label="Attention Visualization", adjacent_separator="", combine_adjacent=True)
    current_tokens = gr.Textbox(label="Current Token")

    # Hidden state that carries the last generation's data to the step slider
    input_tokens_state = gr.State([])
    output_tokens_state = gr.State([])
    relevances_state = gr.State([])

    # The examples and the button run the same function with the same wiring;
    # the output order mirrors the tuple returned by on_generate
    generation_inputs = [input_text, num_tokens]
    generation_outputs = [
        input_tokens_state,
        relevances_state,
        attention_on_inputs,
        output_tokens_state,
        current_tokens,
        generated_output,
        step_slider,
        heatmap_output,
    ]

    gr.Examples(
        examples=examples,
        inputs=generation_inputs,
        outputs=generation_outputs,
        fn=on_generate,
        cache_examples=True,
    )

    generate_button.click(
        fn=on_generate,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )

    # Moving the slider re-renders the highlight for the selected step
    step_slider.change(
        fn=update_visualization_step,
        inputs=[input_tokens_state, relevances_state, output_tokens_state, step_slider],
        outputs=[attention_on_inputs, current_tokens],
    )
# Launch the demo only when this file is run as a script
# (a stray trailing "|" after the call, which was a syntax error, is removed)
if __name__ == "__main__":
    demo.launch()