monet9736 committed on
Commit
dd0c4c7
·
verified ·
1 Parent(s): 42b5f5f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -0
app.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import glob
4
+ import io
5
+ import os
6
+ import random
7
+ import struct
8
+ from contextlib import contextmanager
9
+ from html import escape
10
+
11
+ import msgpack
12
+ import streamlit as st
13
+ import torch
14
+ import tqdm
15
+ from huggingface_hub import HfFileSystem
16
+ from transformers import AutoTokenizer
17
+
18
# Use Streamlit's wide page layout.
st.set_page_config(layout="wide")

# Runtime configuration; each value can be overridden with an environment
# variable of the same name.
#
# MODEL_NAME: Hugging Face repo id. It is used both to load the tokenizer and
# to derive the "datasets/<MODEL_NAME>-viewer-data" path the viewer files are
# downloaded from.
MODEL_NAME = os.environ.get("MODEL_NAME", "MonetLLM/monet-vd-1.4B-100BT-hf")
# CONTEXT_WINDOW: the displayed context for an example at token position j
# covers positions [j - CONTEXT_WINDOW, j + CONTEXT_WINDOW), clipped to the
# sequence bounds.
CONTEXT_WINDOW = int(os.environ.get("CONTEXT_WINDOW", "12"))
# CANDIDATE_THRESHOLD: an expert is eligible for the "Random" button only if
# its stored example list has more than this many entries.
CANDIDATE_THRESHOLD = int(os.environ.get("CANDIDATE_THRESHOLD", "50"))
23
+
24
# CSS injected (via st.markdown with unsafe_allow_html) by st_horizontal().
# It does two things:
#   * hides any Streamlit element container that holds a ".hide-element"
#     node, so the <style> tag and the marker <span> take up no space;
#   * turns the stVerticalBlock that directly contains a ".horizontal-marker"
#     into a wrapping flex row, so widgets created inside it sit side by side.
HORIZONTAL_STYLE = """<style class="hide-element">
/* Hides the style container and removes the extra spacing */
.element-container:has(.hide-element) {
display: none;
}
/*
The selector for >.element-container is necessary to avoid selecting the whole
body of the streamlit app, which is also a stVerticalBlock.
*/
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) {
display: flex;
flex-direction: row !important;
flex-wrap: wrap;
gap: 0.5rem;
align-items: baseline;
}
/* Buttons and their parent container all have a width of 704px, which we need to override */
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div {
width: max-content !important;
}
/* Just an example of how you would style buttons, if desired */
/*
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) button {
border-color: red;
}
*/
</style>"""
51
+
52
+
53
def _read_offset_table(path: str):
    """Return the msgpack-encoded table stored at the tail of ``path``.

    The last four bytes of the file hold a big-endian unsigned 32-bit length;
    the table itself occupies that many bytes immediately before them.
    """
    with open(path, "rb") as fp:
        fp.seek(-4, io.SEEK_END)
        (table_size,) = struct.unpack(">I", fp.read(4))

        fp.seek(-(table_size + 4), io.SEEK_END)
        return msgpack.Unpacker(fp).unpack()


@st.cache_resource
def prepare_routing_resources():
    """Download the viewer data (if missing) and load everything the app needs.

    The result is cached by Streamlit, so the work is done once per process
    rather than on every script rerun.

    Returns:
        A 5-tuple ``(input_tokens, examples_tables, routing_tables,
        candidates, tokenizer)`` where

        * ``input_tokens`` is whatever ``inputs.pt`` contains (used by the
          renderer as a 2-D token-id tensor),
        * ``examples_tables[g]`` / ``routing_tables[g]`` are the trailer
          tables of ``examples-g.msgpack`` / ``routings-g.msgpack`` (used by
          the renderer as byte offsets to seek to),
        * ``candidates[g]`` lists the record indices of ``examples-g.msgpack``
          whose record has more than ``CANDIDATE_THRESHOLD`` entries,
        * ``tokenizer`` is the tokenizer of ``MODEL_NAME``.
    """
    # Fetch every file of the companion dataset into the working directory,
    # skipping files that are already present locally.
    fs = HfFileSystem()
    for filename in fs.glob(f"datasets/{MODEL_NAME}-viewer-data/*"):
        if not os.path.exists(os.path.basename(filename)):
            print(f"[*] Download {filename}...")
            fs.download(filename, ".")

    # NOTE(security): torch.load unpickles the file. The data comes from a
    # remote dataset repo, so only point MODEL_NAME at sources you trust
    # (or switch to weights_only=True once the file is confirmed to hold a
    # plain tensor).
    input_tokens = torch.load("inputs.pt")

    # One examples file per routing group, numbered consecutively from 0.
    num_groups = len(glob.glob("examples-*.msgpack"))
    examples_tables = [
        _read_offset_table(f"examples-{i}.msgpack") for i in tqdm.trange(num_groups)
    ]

    # Scan each examples file from the start, reading one record per table
    # entry, and remember the records that are large enough to be worth
    # showing. Assumes the records are stored back to back in table order
    # beginning at offset 0 -- TODO confirm against the file writer.
    candidates = []
    for i, table in enumerate(tqdm.tqdm(examples_tables)):
        selected = []
        with open(f"examples-{i}.msgpack", "rb") as fp:
            unpacker = msgpack.Unpacker(fp)
            for j in range(len(table)):
                if len(unpacker.unpack()) > CANDIDATE_THRESHOLD:
                    selected.append(j)
        candidates.append(selected)

    # A routings file is expected for every examples file.
    routing_tables = [
        _read_offset_table(f"routings-{i}.msgpack")
        for i in tqdm.trange(len(examples_tables))
    ]

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return input_tokens, examples_tables, routing_tables, candidates, tokenizer
92
+
93
+
94
# Load the viewer resources (served from Streamlit's resource cache after the
# first run) and expose them as module-level globals for the code below.
(
    input_tokens,
    examples_tables,
    routing_tables,
    candidates,
    tokenizer,
) = prepare_routing_resources()
97
+
98
+
99
def render_routing_examples_in_html(router_index: int, expert_id: int) -> str:
    """Render the stored activation examples of one expert as an HTML table.

    Args:
        router_index: Index of the routing group; selects the
            ``examples-<router_index>.msgpack`` / ``routings-<router_index>.msgpack``
            files and the matching offset tables.
        expert_id: Index of the expert inside that group.

    Returns:
        An HTML fragment with one table row per stored example: the routed
        token with its routing score, the surrounding context with each token
        shaded by its score for this expert, and the ``(sequence, position)``
        coordinates of the example.

    Raises:
        IndexError: If ``router_index`` or ``expert_id`` is outside the loaded
            tables.
    """
    # The examples table maps an expert to the byte offset of its record.
    with open(f"examples-{router_index}.msgpack", "rb") as fp:
        fp.seek(examples_tables[router_index][expert_id])
        examples = msgpack.Unpacker(fp).unpack()

    group_offsets = routing_tables[router_index]
    rows = []
    with open(f"routings-{router_index}.msgpack", "rb") as fp:
        for i, j, _ in examples:
            token_offsets = group_offsets[i]
            start = max(j - CONTEXT_WINDOW, 0)
            # Always keep position j itself inside the window so the
            # ``activated[j - start]`` lookup below stays valid even when
            # CONTEXT_WINDOW is 0.
            end = min(max(j + CONTEXT_WINDOW, j + 1), len(token_offsets))

            # Per-token routing maps are read back to back starting at the
            # offset of the first token in the window; a missing key means the
            # expert was not activated for that token.
            fp.seek(token_offsets[start])
            unpacker = msgpack.Unpacker(fp, strict_map_key=False)
            activated = [unpacker.unpack().get(expert_id, 0) for _ in range(start, end)]

            # Re-tokenize the decoded text to obtain character spans. The
            # re-encoded length may differ from the stored row, so token
            # indices are shifted by the length difference (presumably the
            # mismatch sits at the front of the sequence -- TODO confirm).
            full_text = tokenizer.decode(input_tokens[i])
            encodings = tokenizer(full_text, add_special_tokens=False)
            offset = len(encodings.input_ids) - input_tokens.size(1)

            spans, lslice = [], None
            for k in range(start, end):
                if offset + k >= 0 and (sslice := encodings.token_to_chars(offset + k)):
                    span, score = full_text[slice(*sslice)], activated[k - start]
                    # Consecutive tokens that map to the same characters are
                    # merged into one span carrying the highest score.
                    if lslice == sslice:
                        score = max(spans.pop(-1)[1], score)
                    spans.append((escape(span), score))
                    lslice = sslice

            # The score doubles as the alpha channel of the highlight colour.
            spans = [
                f"<span style='background-color: rgba(144, 238, 144, {score})' title='Routing: {score*100:.2f}%'>{span}</span>"
                for span, score in spans
            ]
            rows.append(
                f"""
                <tr>
                    <td align='right'>
                        <span style='font-weight: bold'>
                            {escape(tokenizer.decode(input_tokens[i, j]))} ({activated[j - start] * 100:.2f}%)
                        </span>
                    </td>
                    <td align='left'>
                        (...) {"".join(spans)} (...)
                    </td>
                    <td align='right'>
                        ({i}, {j})
                    </td>
                </tr>
                """
            )

    return f"""
    <div style='background-color: white; color: black; padding: 1em 3em; font-size: 12pt'>
        <h2 style='font-size: 18pt'> Activated Examples of Group {router_index} / Expert {expert_id} </h2>
        <table>
            {"".join(rows)}
        </table>
    </div>
    """
156
+
157
+
158
@contextmanager
def st_horizontal():
    """Context manager that lays out the widgets created inside it in a row.

    It injects ``HORIZONTAL_STYLE`` and then opens a Streamlit container that
    starts with a hidden marker ``<span>``; the stylesheet targets the block
    holding that marker and switches it to a horizontal flex layout.
    """
    marker_html = '<span class="hide-element horizontal-marker"></span>'

    st.markdown(HORIZONTAL_STYLE, unsafe_allow_html=True)
    row = st.container()
    with row:
        # The marker must live inside the container so the CSS selector can
        # find the block to restyle.
        st.markdown(marker_html, unsafe_allow_html=True)
        yield
167
+
168
+
169
# --- Controls: routing group on the left, expert index on the right ---------
col1, col2 = st.columns(2)
with col1:
    router_groups = [f"Routing Group {i}" for i in range(len(examples_tables))]
    # Group 4 is the preferred default; clamp it so datasets with fewer groups
    # do not fail widget validation.
    router_index = st.selectbox(
        "Expert Routing Group",
        router_groups,
        index=min(4, max(len(router_groups) - 1, 0)),
    )
with col2:
    # Valid expert indices are 0 .. len - 1; the widget bounds are inclusive,
    # so the maximum must be the last index rather than the length.
    max_expert_id = len(examples_tables[0]) - 1
    expert_id = st.number_input(
        "Expert Index", 0, max_expert_id, min(54136, max_expert_id)
    )

with st_horizontal():
    show_btn = st.button("Show")
    random_btn = st.button("Random")

if show_btn or random_btn:
    # The selectbox returns the label; map it back to the numeric group index.
    router_index = router_groups.index(router_index)
    if random_btn:
        pool = candidates[router_index]
        if pool:
            expert_id = random.choice(pool)
        else:
            # random.choice raises on an empty sequence; fall back to the
            # expert currently entered in the number input.
            st.warning(
                "No expert in this group has enough examples to pick from; "
                "showing the selected expert instead."
            )
    st.html(render_routing_examples_in_html(router_index, expert_id))