m-ric (HF staff) committed
Commit 098d586 · Parent(s): 1ebe266

Add nice slider and visualization

Files changed (2)
  1. app.py +115 -16
  2. requirements.txt +2 -1
app.py CHANGED
@@ -3,6 +3,11 @@ from transformers import AutoTokenizer
 from lxt.models.llama import LlamaForCausalLM, attnlrp
 from lxt.utils import clean_tokens
 import gradio as gr
+import matplotlib.pyplot as plt
+import seaborn as sns
+import numpy as np
+import io
+from PIL import Image
 
 # Load model and tokenizer
 model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.bfloat16, device_map="cuda")
@@ -11,44 +16,138 @@ tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 # Apply AttnLRP rules
 attnlrp.register(model)
 
-def generate_and_visualize(prompt):
+def really_clean_tokens(tokens):
+    tokens = clean_tokens(tokens)
+    tokens = [token.replace("_", " ").replace("▁", " ").replace("<s>", "") for token in tokens]
+    return tokens
+
+def generate_and_visualize(prompt, num_tokens=10):
     input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)
     input_embeds = model.get_input_embeddings()(input_ids)
 
-    output_logits = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False).logits
-    max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)
+    generated_tokens_ids = []
+    all_relevances = []
+
+    for _ in range(num_tokens):
+        output_logits = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False).logits
+        max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)
+
+        max_logits.backward(max_logits)
+        relevance = input_embeds.grad.float().sum(-1).cpu()[0]
+        all_relevances.append(relevance)
+
+        # Generate next token
+        next_token = max_indices.unsqueeze(0)
+        generated_tokens_ids.append(next_token.item())
 
-    max_logits.backward(max_logits)
-    relevance = input_embeds.grad.float().sum(-1).cpu()[0]
+        # Prepare for next iteration
+        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
+        input_embeds = model.get_input_embeddings()(input_ids)
 
+    # Process tokens and relevances
+    input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
+
+    input_tokens = really_clean_tokens(input_tokens)
+    generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(generated_tokens_ids))
+    generated_text = tokenizer.decode(generated_tokens_ids)
+    print(f"Generated text: {generated_text}")
+
+    return input_tokens, all_relevances, generated_text, len(generated_tokens_ids), generated_tokens
+
+def update_visualization_step(input_tokens, all_relevances, output_tokens, step):
+    relevance = all_relevances[step]
+
     # Normalize relevance between [0, 1] for highlighting
     relevance = (relevance - relevance.min()) / (relevance.max() - relevance.min())
 
-    # Remove '_' characters from token strings
-    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
-    tokens = clean_tokens(tokens)
-    tokens = [token.replace("_", " ").replace("▁", " ").replace("<s>", "") for token in tokens]
-    print("Decoded: ", tokenizer.decode(input_ids[0]))
-
     # Create list of (token, score) tuples for HighlightedText
-    highlighted_tokens = [(token, max(float(score), 0.05)-0.05) for token, score in zip(tokens, relevance)]
+    highlighted_tokens = [(token, max(float(score), 0.05)-0.05) for token, score in zip(input_tokens, relevance)]
+
+    return highlighted_tokens, output_tokens[:step+1]
+
+def generate_heatmap(input_tokens, all_relevances, output_tokens):
+    # Create a matrix of attention scores
+    attention_matrix = np.array([el[:len(all_relevances[0])] for el in all_relevances])
+
+    # Find columns with non-zero values (including adjacent columns)
+    non_zero_cols = np.where(np.abs(attention_matrix).sum(axis=0) > 1.)[0]
+    for col in range(5):
+        non_zero_cols = np.union1d(non_zero_cols, non_zero_cols + col)
+        non_zero_cols = np.union1d(non_zero_cols, non_zero_cols - col)
+    non_zero_cols = np.sort(non_zero_cols)
+    non_zero_cols = non_zero_cols[(non_zero_cols >= 0) & (non_zero_cols < attention_matrix.shape[1])]
+
+    # Filter the matrix and input tokens
+    filtered_matrix = attention_matrix[:, non_zero_cols]
+    filtered_input_tokens = [input_tokens[i] for i in non_zero_cols]
+
+    # Create the heatmap
+    plt.figure(figsize=(20, 7))
+    sns.heatmap(filtered_matrix, xticklabels=filtered_input_tokens, yticklabels=output_tokens, cmap="YlOrRd")
+    plt.title("Attention Heatmap (shows only input tokens where there was strong enough attention)")
+    plt.xlabel("Input Tokens")
+    plt.ylabel("Output Tokens")
+    plt.xticks(rotation=90)
+    plt.yticks(rotation=0)
+
+    # Save the plot to a bytes buffer
+    buf = io.BytesIO()
+    plt.savefig(buf, format='png', bbox_inches='tight')
+    buf.seek(0)
+    plt.close()
+
+    # Return the image as a numpy array
+    return np.array(Image.open(buf))
 
-    return highlighted_tokens
+def on_generate(input_text, num_tokens):
+    input_tokens, relevances, generated_text, generated_length, output_tokens = generate_and_visualize(input_text, num_tokens)
+    heatmap = generate_heatmap(input_tokens, relevances, output_tokens)
+    return (
+        input_tokens,
+        relevances,
+        update_visualization_step(input_tokens, relevances, output_tokens, 0)[0],
+        output_tokens,
+        output_tokens[:1],
+        generated_text,
+        gr.Slider(maximum=generated_length-1, value=0),
+        heatmap
+    )
 
 # Define Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("""# Attention Visualization Demo
+    gr.Markdown("""# Extended Attention Visualization Demo
 
     This demo uses the library [LXT](https://lxt.readthedocs.io/en/latest/quickstart.html#tinyllama) under the hood.""")
 
     input_text = gr.Textbox(label="Input Prompt", lines=5, value="""\
 Context: Mount Everest attracts many climbers, including highly experienced mountaineers. There are two main climbing routes, one approaching the summit from the southeast in Nepal (known as the standard route) and the other from the north in Tibet. While not posing substantial technical climbing challenges on the standard route, Everest presents dangers such as altitude sickness, weather, and wind, as well as hazards from avalanches and the Khumbu Icefall. As of November 2022, 310 people have died on Everest. Over 200 bodies remain on the mountain and have not been removed due to the dangerous conditions. The first recorded efforts to reach Everest's summit were made by British mountaineers. As Nepal did not allow foreigners to enter the country at the time, the British made several attempts on the north ridge route from the Tibetan side. After the first reconnaissance expedition by the British in 1921 reached 7,000 m (22,970 ft) on the North Col, the 1922 expedition pushed the north ridge route up to 8,320 m (27,300 ft), marking the first time a human had climbed above 8,000 m (26,247 ft). The 1924 expedition resulted in one of the greatest mysteries on Everest to this day: George Mallory and Andrew Irvine made a final summit attempt on 8 June but never returned, sparking debate as to whether they were the first to reach the top. Tenzing Norgay and Edmund Hillary made the first documented ascent of Everest in 1953, using the southeast ridge route. Norgay had reached 8,595 m (28,199 ft) the previous year as a member of the 1952 Swiss expedition. The Chinese mountaineering team of Wang Fuzhou, Gonpo, and Qu Yinhua made the first reported ascent of the peak from the north ridge on 25 May 1960. \
 Question: How high did they climb in 1922? According to the text, the 1922 expedition reached 8,""")
+    num_tokens = gr.Slider(minimum=1, maximum=50, step=1, value=10, label="Number of tokens to generate")
     generate_button = gr.Button("Generate and Visualize")
 
-    output = gr.HighlightedText(label="Attention Visualization", adjacent_separator="", combine_adjacent=True)
+    generated_output = gr.Textbox(label="Generated Text")
+    heatmap_output = gr.Image(label="Attention Heatmap")
+
+    step_slider = gr.Slider(minimum=0, maximum=1, step=1, value=0, label="Visualization Step")
+    attention_on_inputs = gr.HighlightedText(label="Attention Visualization", adjacent_separator="", combine_adjacent=True)
+    current_tokens = gr.Textbox(label="Current Token")
+
+    input_tokens_state = gr.State([])
+    output_tokens_state = gr.State([])
+    relevances_state = gr.State([])
 
-    generate_button.click(generate_and_visualize, inputs=input_text, outputs=output)
+    generate_button.click(
+        on_generate,
+        inputs=[input_text, num_tokens],
+        outputs=[input_tokens_state, relevances_state, attention_on_inputs, output_tokens_state, current_tokens, generated_output, step_slider, heatmap_output]
+    )
+
+    step_slider.change(
+        update_visualization_step,
+        inputs=[input_tokens_state, relevances_state, output_tokens_state, step_slider],
+        outputs=[attention_on_inputs, current_tokens]
+    )
 
 # Launch the demo
 if __name__ == "__main__":
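
For readers unfamiliar with the slider wiring introduced above, here is a minimal, self-contained sketch of the same Gradio pattern: gr.State components carry per-step data between events, and the click handler returns a gr.Slider(...) instance to re-configure the step slider's range at runtime. This sketch is not part of the commit; it assumes Gradio 4.x semantics (returning a component instance from an event handler updates that component), and the generate/show_step functions are hypothetical stand-ins for generate_and_visualize and update_visualization_step.

import gradio as gr

def generate(n_steps):
    # Hypothetical stand-in for generate_and_visualize: one item per generated step.
    steps = [f"step {i}" for i in range(int(n_steps))]
    # Returning a gr.Slider(...) instance re-configures the step slider at runtime (Gradio 4.x).
    return steps, steps[0], gr.Slider(maximum=len(steps) - 1, value=0)

def show_step(steps, step):
    # Hypothetical stand-in for update_visualization_step.
    return steps[int(step)]

with gr.Blocks() as demo:
    n_steps = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Steps to generate")
    run = gr.Button("Run")

    steps_state = gr.State([])  # keeps per-step data between events
    step_slider = gr.Slider(minimum=0, maximum=1, step=1, value=0, label="Step")
    current = gr.Textbox(label="Current step")

    run.click(generate, inputs=n_steps, outputs=[steps_state, current, step_slider])
    step_slider.change(show_step, inputs=[steps_state, step_slider], outputs=current)

if __name__ == "__main__":
    demo.launch()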
requirements.txt CHANGED
@@ -1,2 +1,3 @@
 accelerate
-lxt
+lxt
+seaborn