luca-peric committed
Commit f2f927b · 1 Parent(s): ad774a9

cleaning things up via gemini 2.5 pro

Files changed (1)
  1. app.py +420 -330
app.py CHANGED
@@ -3,420 +3,510 @@ import gradio as gr
  import torch
  import itertools # For color cycling
  import tiktoken # For GPT-4 tokenizer
- from transformers import AutoTokenizer # For Llama3 tokenizer - AutoModel usually not needed just for tokenizer
-
- # Bytelatent imports (assuming they are in the python path)
  try:
      from bytelatent.data.file_util import get_fs
      from bytelatent.generate_patcher import patcher_nocache
      from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
      from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
      from bytelatent.args import TrainArgs
-     from download_blt_weights import main as ensure_present
-     BLT_AVAILABLE = True
  except ImportError as e:
-     print(f"Warning: Bytelatent libraries not found. Bytelatent functionality will be disabled. Error: {e}")
-     BLT_AVAILABLE = False
      # Define dummy classes/functions if BLT is not available to avoid NameErrors later
      class BltTokenizer: pass
      class TrainArgs: pass
      def patcher_nocache(*args, **kwargs): return None
      def plot_entropies(*args, **kwargs): return None
      def ensure_present(*args, **kwargs): pass

- # --- Global Setup ---
-
- # Define colors for patches/tokens
- VIZ_COLORS = [
-     "#a6cee3", "#1f78b4", "#b2df8a", "#33a02c", "#fb9a99", "#e31a1c",
-     "#fdbf6f", "#ff7f00", "#cab2d6", "#6a3d9a", "#ffff99", "#b15928"
- ] # Add more if you expect many segments
-
- LLAMA3_MODEL_NAME = "meta-llama/Meta-Llama-3-8B" # Or choose another variant like Instruct
-
- # --- Helper Functions ---
-
- def create_bytelatent_highlight_data(tokenizer, patch_lengths_tensor, tokens_tensor, colors):
-     """Generates data for gr.HighlightedText based on bytelatent patches."""
-     if not BLT_AVAILABLE:
-         return [("Bytelatent library not available.", "Error")]
-     if patch_lengths_tensor is None or tokens_tensor is None or patch_lengths_tensor.numel() == 0:
-         return None
-     patch_lengths = patch_lengths_tensor.tolist()
-     all_tokens = tokens_tensor.tolist()
-     highlighted_data = []
-     current_token_index = 0
-     patch_count = 0 # Initialize patch count
-     # color_cycler = itertools.cycle(colors) # Moved inside loop if needed per-patch
-     for i, length in enumerate(patch_lengths):
-         if length <= 0: continue
-         patch_token_ids = all_tokens[current_token_index : current_token_index + length]
-         if not patch_token_ids: continue
-         try: patch_text = tokenizer.decode(patch_token_ids)
-         except Exception as decode_err:
-             print(f"Warning: Bytelatent patch decoding failed: {decode_err}")
-             patch_text = f"[Decode Error: {len(patch_token_ids)} tokens]"
-         patch_label = f"BL Patch {i+1}"
-         highlighted_data.append((patch_text, patch_label))
-         patch_count += 1 # Increment count for each valid patch added
-         current_token_index += length
-
-     # Handle remainder separately, don't count it as a 'patch'
-     if current_token_index != len(all_tokens):
-         print(f"Warning: Bytelatent token mismatch. Consumed {current_token_index}, total {len(all_tokens)}")
-         remaining_tokens = all_tokens[current_token_index:]
-         if remaining_tokens:
-             try: remaining_text = tokenizer.decode(remaining_tokens)
-             except Exception: remaining_text = f"[Decode Error: {len(remaining_tokens)} remaining tokens]"
-             highlighted_data.append((remaining_text, "BL Remainder"))
-
-     # Return both highlighted data and the calculated patch count
-     return highlighted_data, patch_count
-
-
- def create_tiktoken_highlight_data(prompt, colors):
-     """Generates data for gr.HighlightedText based on tiktoken (gpt-4) tokens."""
      try:
-         enc = tiktoken.get_encoding("cl100k_base")
-         tiktoken_ids = enc.encode(prompt)
          highlighted_data = []
-         # color_cycler = itertools.cycle(colors) # Moved inside loop if needed per-token
          for i, token_id in enumerate(tiktoken_ids):
-             try: token_text = enc.decode([token_id])
-             except UnicodeDecodeError:
                  try:
-                     token_bytes = enc.decode_single_token_bytes(token_id)
                      token_text = f"[Bytes: {token_bytes.hex()}]"
                  except Exception: token_text = "[Decode Error]"
              except Exception as e:
-                 print(f"Unexpected tiktoken decode error: {e}")
-                 token_text = "[Decode Error]"
              token_label = f"GPT4 Tk {i+1}"
              highlighted_data.append((token_text, token_label))
          token_count = len(tiktoken_ids)
-         print(f"Tiktoken processing complete. Found {token_count} tokens.")
-         return highlighted_data, token_count
-     except ImportError:
-         print("Error: tiktoken library not found. Please install it: pip install tiktoken")
-         return [("tiktoken library not installed.", "Error")], 0
-     except Exception as tiktoken_err:
-         print(f"Error during tiktoken processing: {tiktoken_err}")
-         return [(f"Error processing with tiktoken: {str(tiktoken_err)}", "Error")], 0

- def create_llama3_highlight_data(prompt, colors, model_name=LLAMA3_MODEL_NAME):
      """Generates data for gr.HighlightedText based on Llama 3 tokenizer."""
      try:
-         # Load Llama 3 tokenizer from Hugging Face Hub
-         print(f"Loading Llama 3 tokenizer: {model_name}")
-         # Use trust_remote_code=True if required by the specific model revision
-         tokenizer = AutoTokenizer.from_pretrained(model_name) #, trust_remote_code=True)
-         print("Llama 3 tokenizer loaded.")
-
-         # Encode the prompt
          llama_token_ids = tokenizer.encode(prompt)
-
          highlighted_data = []
-         # color_cycler = itertools.cycle(colors) # Moved inside loop if needed per-token
-
          for i, token_id in enumerate(llama_token_ids):
              try:
-                 # Decode individual token.
                  token_text = tokenizer.decode([token_id])
              except Exception as e:
-                 print(f"Unexpected Llama 3 decode error for token {token_id}: {e}")
                  token_text = "[Decode Error]"

-             token_label = f"Llama3 Tk {i+1}" # Clearer label prefix
              highlighted_data.append((token_text, token_label))

          token_count = len(llama_token_ids)
-         print(f"Llama 3 processing complete. Found {token_count} tokens.")
-         return highlighted_data, token_count
-
-     except ImportError:
-         print("Error: transformers or sentencepiece library not found. Please install them: pip install transformers sentencepiece")
-         return [("transformers/sentencepiece library not installed.", "Error")], 0
-     except OSError as e:
-         # Handle errors like model not found, network issues, authentication needed
-         print(f"Error loading Llama 3 tokenizer '{model_name}': {e}")
-         error_msg = f"Could not load Llama 3 tokenizer '{model_name}'. Check model name and network."
-         if "authentication" in str(e).lower():
-             error_msg = f"Authentication required for Llama 3 tokenizer '{model_name}'. Use `huggingface-cli login`."
-         return [(f"{error_msg} Error: {e}", "Error")], 0
-     except Exception as llama_err:
-         print(f"Error during Llama 3 processing: {llama_err}")
-         import traceback
-         traceback.print_exc() # Print full traceback for debugging
-         return [(f"Error processing with Llama 3: {str(llama_err)}", "Error")], 0

  # --- Main Processing Function ---

- def process_text(prompt: str, model_name: str = "blt-1b"):
      """
      Processes the input prompt using ByteLatent, Tiktoken, and Llama 3,
      returning visualizations, counts, and status.
-
-     Args:
-         prompt: The input text string from the Gradio interface.
-         model_name: The name of the bytelatent model to use.
-
-     Returns:
-         A tuple containing:
-         - Matplotlib Figure for the entropy plot (or None).
-         - List of tuples for bytelatent gr.HighlightedText (or None).
-         - Integer count of bytelatent patches.
-         - List of tuples for tiktoken gr.HighlightedText (or None).
-         - Integer count of tiktoken tokens.
-         - List of tuples for Llama 3 gr.HighlightedText (or None).
-         - Integer count of Llama 3 tokens.
-         - Status/Error message string.
      """
      fig = None
-     bl_highlighted_data = None
-     tk_highlighted_data = None
-     llama_highlighted_data = None
-     bl_count = 0
-     tk_count = 0
-     llama_count = 0
-     status_message = "Starting processing..."
-
-     # --- 1. Tiktoken Processing (Independent) ---
-     status_message += "\nProcessing with Tiktoken (gpt-4)..."
-     tk_highlighted_data, tk_count_calc = create_tiktoken_highlight_data(prompt, VIZ_COLORS)
-     if tk_highlighted_data and tk_highlighted_data[0][1] == "Error":
-         status_message += f"\nTiktoken Error: {tk_highlighted_data[0][0]}"
-         tk_count = 0 # Ensure count is 0 on error
-     else:
-         tk_count = tk_count_calc # Assign calculated count
-         status_message += f"\nTiktoken processing successful ({tk_count} tokens)."
-
-     # --- 2. Llama 3 Processing (Independent) ---
-     status_message += "\nProcessing with Llama 3 tokenizer..."
-     llama_highlighted_data, llama_count_calc = create_llama3_highlight_data(prompt, VIZ_COLORS)
-     if llama_highlighted_data and llama_highlighted_data[0][1] == "Error":
-         status_message += f"\nLlama 3 Error: {llama_highlighted_data[0][0]}"
-         llama_count = 0 # Ensure count is 0 on error
      else:
-         llama_count = llama_count_calc # Assign calculated count
-         status_message += f"\nLlama 3 processing successful ({llama_count} tokens)."
-
-     # --- 3. Bytelatent Processing ---
-     if BLT_AVAILABLE:
-         try:
-             status_message += f"\nLoading Bytelatent entropy model for '{model_name}'..."
-             # (Bytelatent loading code remains the same)
-             consolidated_path = os.path.join("hf-weights", model_name)
-             train_args_path = os.path.join(consolidated_path, "params.json")
-             if not os.path.exists(train_args_path): raise FileNotFoundError(f"BLT training args not found at {train_args_path}.")
-             fs = get_fs(train_args_path); train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
-             bl_tokenizer = train_args.data.tokenizer_args.build(); assert isinstance(bl_tokenizer, BltTokenizer)
-             patcher_args = train_args.data.patcher_args.model_copy(deep=True); patcher_args.realtime_patching = True
-             device = "cuda" if torch.cuda.is_available() else "cpu"; print(f"Using BLT device: {device}")
-             patcher_args.patching_device = device; patcher_args.device = device
-             entropy_model_dir = os.path.join(consolidated_path, "entropy_model")
-             if not os.path.exists(entropy_model_dir): raise FileNotFoundError(f"Entropy model directory not found at {entropy_model_dir}.")
-             patcher_args.entropy_model_checkpoint_dir = entropy_model_dir; bl_patcher = patcher_args.build()
-             status_message += "\nBytelatent entropy model loaded."
-
-             # --- Processing ---
-             status_message += "\nRunning Bytelatent entropy model patching..."
-             print(f"Processing prompt with entropy model: '{prompt}'")
-             prompt_bytes = prompt.encode('utf-8')
-             max_bytes = 512 # Define max bytes
-             if len(prompt_bytes) > max_bytes:
-                 print(f"Warning: Prompt exceeds {max_bytes} bytes ({len(prompt_bytes)}). Truncating for entropy model.")
-                 # Find the byte position that corresponds to the last full character within the limit
-                 # This avoids splitting a multi-byte character
-                 try:
-                     last_char_pos = prompt_bytes[:max_bytes].rfind(b' ') # Simple whitespace split point find, might not be perfect
-                     if last_char_pos == -1: # If no space, truncate hard (less ideal)
-                         prompt_bl = prompt_bytes[:max_bytes].decode('utf-8', errors='ignore')
-                     else:
-                         prompt_bl = prompt_bytes[:last_char_pos].decode('utf-8', errors='ignore')
-
-                 except Exception: # Fallback to simple truncation on decode errors
-                     prompt_bl = prompt_bytes[:max_bytes].decode('utf-8', errors='ignore')
-
-                 status_message += f"\nWarning: Prompt truncated to approx {len(prompt_bl.encode('utf-8'))} bytes for Bytelatent entropy model."
-             else:
-                 prompt_bl = prompt
-
-             results = patcher_nocache([prompt_bl], tokenizer=bl_tokenizer, patcher=bl_patcher)

-             if not results:
-                 print("Bytelatent entropy processing returned no results.")
-                 status_message += "\nBytelatent entropy model warning: Processing completed, but no results were generated."
-                 bl_highlighted_data = [("No patches generated.", "Info")]
-                 bl_count = 0
-             else:
-                 batch_patch_lengths, batch_scores, batch_tokens = results
-                 patch_lengths, scores, tokens = batch_patch_lengths[0], batch_scores[0], batch_tokens[0]
-                 # --- Visualization Data Generation ---
-                 try: decoded_output_for_plot = bl_tokenizer.decode(tokens.tolist())
-                 except Exception as decode_err:
-                     print(f"Warning: Error decoding full sequence for plot: {decode_err}")
-                     decoded_output_for_plot = prompt_bl # Use truncated prompt for plot if decode fails
-
-                 fig = plot_entropies(patch_lengths, scores, decoded_output_for_plot, threshold=bl_patcher.threshold)
-                 bl_highlighted_data, bl_count_calc = create_bytelatent_highlight_data(bl_tokenizer, patch_lengths, tokens, VIZ_COLORS)
-                 bl_count = bl_count_calc # Assign calculated count
-
-                 status_message += f"\nBytelatent entropy model processing and visualization successful ({bl_count} patches)."
-                 print("Bytelatent Entropy model processing and decoding complete.")
-
-         except FileNotFoundError as e:
-             print(f"Bytelatent Error: {e}")
-             status_message += f"\nBytelatent FileNotFoundError: {str(e)}"
-             bl_highlighted_data = [(f"Bytelatent Error: {e}", "Error")]
-             bl_count = 0
-         except Exception as e:
-             print(f"An unexpected Bytelatent error occurred: {e}")
-             import traceback
-             traceback.print_exc()
-             status_message += f"\nBytelatent Unexpected Error: {str(e)}"
-             bl_highlighted_data = [(f"Bytelatent Error: {e}", "Error")]
-             bl_count = 0
      else:
-         status_message += "\nBytelatent processing skipped (library not found)."
-         bl_highlighted_data = [("Bytelatent library not available.", "Error")]
-         bl_count = 0
-         fig = None # Ensure fig is None if BLT is skipped

-     # Return all generated data and the final status message
-     return fig, bl_highlighted_data, bl_count, tk_highlighted_data, tk_count, llama_highlighted_data, llama_count, status_message

  # --- Gradio Interface ---

- # Create color maps for HighlightedText dynamically
- MAX_EXPECTED_SEGMENTS = 2000 # Increased max segments further just in case
- common_error_map = {"Error": "#FF0000", "Info": "#808080"} # Red for errors, Gray for info
-
- bytelatent_color_map = {f"BL Patch {i+1}": color for i, color in zip(range(MAX_EXPECTED_SEGMENTS), itertools.cycle(VIZ_COLORS))}
- bytelatent_color_map["BL Remainder"] = "#AAAAAA"; bytelatent_color_map.update(common_error_map)

- tiktoken_color_map = {f"GPT4 Tk {i+1}": color for i, color in zip(range(MAX_EXPECTED_SEGMENTS), itertools.cycle(VIZ_COLORS))}
- tiktoken_color_map.update(common_error_map)

- llama3_color_map = {f"Llama3 Tk {i+1}": color for i, color in zip(range(MAX_EXPECTED_SEGMENTS), itertools.cycle(VIZ_COLORS))}
- llama3_color_map.update(common_error_map)
-
-
- with gr.Blocks(theme=gr.themes.Origin()) as iface:
-     gr.Markdown("# BLT's Entropy-based Patcher vs. Tokenizer Visualisation")
-     gr.Markdown(
-         "Enter text to visualize its segmentation according to different methods:\n"
-         "1. **Byte Latent Transformer (BLT):** Entropy-based patching plot and patched text (_for this space ONLY_ - limited to ~512 bytes).\n"
-         "2. **Tiktoken (GPT-4):** Text segmented by `cl100k_base` tokens.\n"
-         f"3. **Llama 3:** Text segmented by the `{LLAMA3_MODEL_NAME}` tokenizer."
-     )

      with gr.Row():
          with gr.Column(scale=1): # Input Column
              prompt_input = gr.Textbox(
                  label="Input Prompt",
-                 value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
                  placeholder="Enter text here...",
-                 max_length=512, # Allow even longer input, Bytelatent will truncate
                  lines=5,
-                 info="For this space ONLY, processing is limited to ~512 bytes."
              )
              submit_button = gr.Button("Generate Visualizations", variant="primary")
-             status_output = gr.Textbox(label="Processing Status", interactive=False, lines=7) # Increased lines slightly

          with gr.Column(scale=2): # Output Column
-
              # --- Bytelatent Output Area ---
-             with gr.Row(equal_height=False): # Use Row to place title and count together
-                 gr.Markdown("### BLT Entropy Patcher Output (`blt_main_entropy_100m_512w`)")
-
-                 bl_count_output = gr.Number(label="Patch Count", value=0, interactive=False, scale=1, step=1) # Added Number output
-             highlighted_output_bl = gr.HighlightedText(
-                 label="BLT's Entropy-based Patches",
-                 color_map=bytelatent_color_map,
-                 show_legend=False, # Legend can get very long
-                 # show_label=False, # Hide the HighlightedText label as we have the markdown title
-                 show_inline_category=False,
-                 # container=False, # Reduces vertical space slightly
-             )
-             plot_output = gr.Plot(label="Entropy vs. Token Index", show_label=True)

              # --- Tiktoken Output Area ---
-             with gr.Row(equal_height=False):
-                 gr.Markdown("### Tiktoken Output (`cl100k_base`)")
-
-                 tk_count_output = gr.Number(label="Token Count", value=0, interactive=False, scale=1, step=1) # Added Number output
-             highlighted_output_tk = gr.HighlightedText(
-                 label="Tiktoken Segmented Text",
-                 color_map=tiktoken_color_map,
-                 show_legend=False,
-                 show_inline_category=False,
-                 # show_label=False,
-                 # container=False,
-             )

              # --- Llama 3 Output Area ---
-             with gr.Row(equal_height=False):
-                 gr.Markdown(f"### Llama 3 Output (`{LLAMA3_MODEL_NAME}`)")
-
-                 llama_count_output = gr.Number(label="Token Count", value=0, interactive=False, scale=1, step=1) # Added Number output
-             highlighted_output_llama = gr.HighlightedText(
-                 label="Llama 3 Segmented Text",
-                 color_map=llama3_color_map,
-                 show_legend=False,
-                 show_inline_category=False,
-                 # show_label=False,
-                 # container=False,
-             )

      # Define the action for the button click
      submit_button.click(
          fn=process_text,
          inputs=prompt_input,
-         # Ensure order matches the 8 return values of process_text
          outputs=[
-             plot_output, # fig
-             highlighted_output_bl, # bl_highlighted_data
-             bl_count_output, # bl_count <-- New
-             highlighted_output_tk, # tk_highlighted_data
-             tk_count_output, # tk_count <-- New
-             highlighted_output_llama, # llama_highlighted_data
-             llama_count_output, # llama_count <-- New
-             status_output # status_message
-         ]
      )

  # --- Launch the Gradio App ---
  if __name__ == "__main__":
-     print("Checking required libraries...")
-     try:
-         import tiktoken
-         print("- tiktoken found.")
-     except ImportError:
-         print("WARNING: 'tiktoken' not found. GPT-4 visualization will fail. Install with: pip install tiktoken")
-     try:
-         import transformers
-         import sentencepiece
-         print("- transformers found.")
-         print("- sentencepiece found.")
-     except ImportError:
-         print("WARNING: 'transformers' or 'sentencepiece' not found. Llama 3 visualization will fail. Install with: pip install transformers sentencepiece")
-
-     if BLT_AVAILABLE:
-         print("- Bytelatent libraries found.")
-         # Ensure bytelatent model is present only if library is available
-         try:
-             print("Ensuring Bytelatent model 'blt-1b' weights are present...")
-             ensure_present(["blt-1b"])
-             print("Bytelatent model check complete.")
-         except Exception as blt_dl_err:
-             print(f"WARNING: Failed to ensure Bytelatent model presence: {blt_dl_err}")
-     else:
-         print("INFO: Bytelatent libraries not found, skipping related functionality.")
-
-     print(f"Attempting to use Llama 3 Tokenizer: {LLAMA3_MODEL_NAME}. Ensure you have access (e.g., via `huggingface-cli login` if needed).")
-     print("Launching Gradio interface...")
      iface.launch()
  import torch
  import itertools # For color cycling
  import tiktoken # For GPT-4 tokenizer
+ from transformers import AutoTokenizer, HfArgumentParser # For Llama3 tokenizer & args potentially
+ import traceback # For detailed error logging
+ import logging # For better logging practices
+ from typing import Optional, Tuple, List, Dict, Any
+ import matplotlib.figure # For type hinting
+ import matplotlib.pyplot as plt
+
+ # --- Configuration ---
+
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ class Config:
+     # Visualization
+     VIZ_COLORS: List[str] = [
+         "#a6cee3", "#1f78b4", "#b2df8a", "#33a02c", "#fb9a99", "#e31a1c",
+         "#fdbf6f", "#ff7f00", "#cab2d6", "#6a3d9a", "#ffff99", "#b15928"
+     ]
+     MAX_EXPECTED_SEGMENTS: int = 1 # Max segments for color map generation
+
+     # Model/Tokenizer Names
+     LLAMA3_MODEL_NAME: str = "meta-llama/Meta-Llama-3-8B" # Or choose another variant like Instruct
+     TIKTOKEN_ENCODING_NAME: str = "cl100k_base"
+     BLT_MODEL_NAME: str = "blt-1b" # Default Bytelatent model
+
+     # Bytelatent Specific
+     BLT_WEIGHTS_DIR: str = "hf-weights"
+     BLT_MAX_BYTES_FOR_DEMO: int = 512 # Limit for this specific demo's entropy model
+
+     # Gradio
+     DEFAULT_PROMPT: str = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
+     GRADIO_THEME = gr.themes.Origin()
+     GRADIO_TITLE: str = "BLT's Entropy-based Patcher vs. Tokenizer Visualisation"
+     GRADIO_DESC: str = (
+         "Enter text to visualize its segmentation according to different methods:\n"
+         f"1. **Byte Latent Transformer (BLT):** Entropy-based patching plot and patched text (_for this space ONLY_ - limited to ~{BLT_MAX_BYTES_FOR_DEMO} bytes using `blt_main_entropy_100m_512w`).\n"
+         f"2. **Tiktoken (GPT-4):** Text segmented by `{TIKTOKEN_ENCODING_NAME}` tokens.\n"
+         f"3. **Llama 3:** Text segmented by the `{LLAMA3_MODEL_NAME}` tokenizer."
+     )

+ # --- Bytelatent Processor ---
+
+ # Attempt to import Bytelatent libraries
  try:
      from bytelatent.data.file_util import get_fs
      from bytelatent.generate_patcher import patcher_nocache
      from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
      from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
      from bytelatent.args import TrainArgs
+     from download_blt_weights import main as ensure_present # Assuming this downloads weights
+     _BLT_AVAILABLE = True
+     logging.info("Bytelatent libraries found.")
  except ImportError as e:
+     logging.warning(f"Bytelatent libraries not found. Bytelatent functionality will be disabled. Error: {e}")
+     _BLT_AVAILABLE = False
      # Define dummy classes/functions if BLT is not available to avoid NameErrors later
      class BltTokenizer: pass
      class TrainArgs: pass
      def patcher_nocache(*args, **kwargs): return None
      def plot_entropies(*args, **kwargs): return None
      def ensure_present(*args, **kwargs): pass
+     matplotlib = None # No plotting if BLT isn't there
+
+ class BytelatentProcessor:
+     """Handles loading and running the Bytelatent entropy model."""
+     def __init__(self, model_name: str, weights_dir: str):
+         self.model_name = model_name
+         self.weights_dir = weights_dir
+         self.is_available: bool = False
+         self.tokenizer: Optional[BltTokenizer] = None
+         self.patcher: Optional[Any] = None # Type depends on bytelatent implementation
+         self.device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+         if _BLT_AVAILABLE:
+             try:
+                 # 1. Ensure weights are present
+                 logging.info(f"Ensuring Bytelatent model '{model_name}' weights are present...")
+                 ensure_present([model_name]) # Call the download script
+                 logging.info("Bytelatent model check complete.")
+
+                 # 2. Load Bytelatent model components
+                 consolidated_path = os.path.join(self.weights_dir, model_name)
+                 train_args_path = os.path.join(consolidated_path, "params.json")
+                 entropy_model_dir = os.path.join(consolidated_path, "entropy_model")
+
+                 if not os.path.exists(train_args_path):
+                     raise FileNotFoundError(f"BLT training args not found at {train_args_path}.")
+                 if not os.path.exists(entropy_model_dir):
+                     raise FileNotFoundError(f"BLT Entropy model directory not found at {entropy_model_dir}.")
+
+                 fs = get_fs(train_args_path)
+                 train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
+
+                 self.tokenizer = train_args.data.tokenizer_args.build()
+                 assert isinstance(self.tokenizer, BltTokenizer), "Failed to build Bytelatent Tokenizer"
+
+                 patcher_args = train_args.data.patcher_args.model_copy(deep=True)
+                 patcher_args.realtime_patching = True
+                 patcher_args.patching_device = self.device
+                 patcher_args.device = self.device
+                 patcher_args.entropy_model_checkpoint_dir = entropy_model_dir
+                 self.patcher = patcher_args.build()
+
+                 self.is_available = True
+                 logging.info(f"Bytelatent processor for '{model_name}' loaded successfully on device '{self.device}'.")
+
+             except FileNotFoundError as e:
+                 logging.error(f"Bytelatent setup failed: Required file/directory not found. {e}")
+             except Exception as e:
+                 logging.error(f"An unexpected error occurred during Bytelatent setup: {e}")
+                 logging.error(traceback.format_exc())
+         else:
+             logging.warning("Skipping Bytelatent setup as libraries are unavailable.")
+
+     def _create_highlight_data(self, patch_lengths: torch.Tensor, tokens: torch.Tensor) -> Tuple[List[Tuple[str, str]], int]:
+         """Generates data for gr.HighlightedText based on bytelatent patches."""
+         if not self.is_available or self.tokenizer is None or patch_lengths.numel() == 0:
+             return [("Bytelatent processing failed or produced no patches.", "Error")], 0
+
+         patch_lengths_list = patch_lengths.tolist()
+         all_token_ids = tokens.tolist()
+         highlighted_data = []
+         current_token_index = 0
+         patch_count = 0
+
+         for i, length in enumerate(patch_lengths_list):
+             if length <= 0: continue
+             patch_token_ids = all_token_ids[current_token_index : current_token_index + length]
+             if not patch_token_ids: continue
+
+             try:
+                 patch_text = self.tokenizer.decode(patch_token_ids)
+             except Exception as decode_err:
+                 logging.warning(f"Bytelatent patch decoding failed: {decode_err}")
+                 patch_text = f"[Decode Error: {len(patch_token_ids)} tokens]"
+
+             patch_label = f"BL Patch {i+1}"
+             highlighted_data.append((patch_text, patch_label))
+             patch_count += 1
+             current_token_index += length
+
+         # Handle remainder tokens if any
+         if current_token_index < len(all_token_ids):
+             remaining_tokens = all_token_ids[current_token_index:]
+             try:
+                 remaining_text = self.tokenizer.decode(remaining_tokens)
+                 label = "BL Remainder"
+             except Exception:
+                 remaining_text = f"[Decode Error: {len(remaining_tokens)} remaining tokens]"
+                 label = "Error"
+             highlighted_data.append((remaining_text, label))
+             logging.warning(f"Bytelatent token mismatch. Consumed {current_token_index}, total {len(all_token_ids)}. Remainder added.")
+
+         return highlighted_data, patch_count
+
+     def process(self, prompt: str, max_bytes: int) -> Tuple[Optional[matplotlib.figure.Figure], List[Tuple[str, str]], int, str]:
+         """Processes the prompt using the loaded Bytelatent model."""
+         status = ""
+         if not self.is_available or self.tokenizer is None or self.patcher is None:
+             status = "Bytelatent processor not available."
+             return None, [("Bytelatent not available.", "Error")], 0, status
+
+         # Truncate prompt if necessary for this demo's model
+         prompt_bytes = prompt.encode('utf-8')
+         prompt_bl = prompt
+         if len(prompt_bytes) > max_bytes:
+             try:
+                 # Find last full character within limit (simple space split fallback)
+                 try:
+                     prompt_bl = prompt_bytes[:max_bytes].decode('utf-8', errors='strict')
+                     # If successful, find last space to avoid cutting mid-word visually
+                     last_space = prompt_bl.rfind(' ')
+                     if last_space != -1:
+                         prompt_bl = prompt_bl[:last_space]
+                 except UnicodeDecodeError:
+                     # If strict fails, find last valid byte sequence start before max_bytes
+                     i = max_bytes
+                     while i > 0:
+                         try:
+                             prompt_bytes[:i].decode('utf-8', errors='strict')
+                             break # Found valid end point
+                         except UnicodeDecodeError:
+                             i -= 1
+                     prompt_bl = prompt_bytes[:i].decode('utf-8', errors='ignore') # Decode ignoring errors now
+
+                 trunc_len = len(prompt_bl.encode('utf-8'))
+                 status = f"Warning: Prompt truncated to {trunc_len} bytes for Bytelatent entropy model.\n"
+                 logging.warning(status.strip())
+             except Exception as trunc_err:
+                 # Fallback if complex truncation fails
+                 prompt_bl = prompt_bytes[:max_bytes].decode('utf-8', errors='ignore')
+                 trunc_len = len(prompt_bl.encode('utf-8'))
+                 status = f"Warning: Prompt aggressively truncated to ~{trunc_len} bytes due to encoding issue. Error: {trunc_err}\n"
+                 logging.warning(status.strip())
+
+         # Run Bytelatent patching
+         try:
+             logging.info(f"Running Bytelatent entropy model patching on {len(prompt_bl.encode('utf-8'))} bytes...")
+             results = patcher_nocache([prompt_bl], tokenizer=self.tokenizer, patcher=self.patcher)
+             status += "Bytelatent patching executed.\n"
+
+             if not results:
+                 logging.warning("Bytelatent entropy processing returned no results.")
+                 status += "Warning: Bytelatent generated no patches."
+                 return None, [("No patches generated by Bytelatent.", "Info")], 0, status
+
+             batch_patch_lengths, batch_scores, batch_tokens = results
+             patch_lengths, scores, tokens = batch_patch_lengths[0], batch_scores[0], batch_tokens[0]
+
+             # Create highlighted text data
+             highlighted_data, patch_count = self._create_highlight_data(patch_lengths, tokens)
+
+             # Create plot
+             fig = None
+             if plot_entropies is not None: # Check if plotting function is available
+                 try:
+                     # Use the potentially truncated prompt_bl for the plot text axis if full decode fails
+                     decoded_output_for_plot = self.tokenizer.decode(tokens.tolist())
+                 except Exception as decode_err:
+                     logging.warning(f"Error decoding full BLT token sequence for plot: {decode_err}. Using (truncated) input prompt for plot axis.")
+                     decoded_output_for_plot = prompt_bl
+
+                 fig = plot_entropies(patch_lengths, scores, decoded_output_for_plot, threshold=self.patcher.threshold)
+                 status += f"Bytelatent plot generated. Found {patch_count} patches.\n"
+             else:
+                 status += "Plotting unavailable.\n"
+
+             logging.info(f"Bytelatent processing complete. Patches: {patch_count}")
+             return fig, highlighted_data, patch_count, status.strip()
+
+         except Exception as e:
+             logging.error(f"An error occurred during Bytelatent processing: {e}")
+             logging.error(traceback.format_exc())
+             status += f"Error during Bytelatent processing: {e}"
+             return None, [(f"Bytelatent Error: {e}", "Error")], 0, status.strip()
+

+ # --- Tokenizer Helpers ---

+ def create_tiktoken_highlight_data(prompt: str, encoding: tiktoken.Encoding) -> Tuple[List[Tuple[str, str]], int, str]:
+     """Generates data for gr.HighlightedText based on tiktoken."""
+     status = "Processing with Tiktoken...\n"
      try:
+         tiktoken_ids = encoding.encode(prompt)
          highlighted_data = []
          for i, token_id in enumerate(tiktoken_ids):
+             try:
+                 token_text = encoding.decode([token_id])
+             except (UnicodeDecodeError, TypeError): # Handle bytes that don't form valid unicode
                  try:
+                     token_bytes = encoding.decode_single_token_bytes(token_id)
                      token_text = f"[Bytes: {token_bytes.hex()}]"
                  except Exception: token_text = "[Decode Error]"
              except Exception as e:
+                 logging.warning(f"Unexpected tiktoken decode error for token ID {token_id}: {e}")
+                 token_text = "[Decode Error]"
+
              token_label = f"GPT4 Tk {i+1}"
              highlighted_data.append((token_text, token_label))
+
          token_count = len(tiktoken_ids)
+         status += f"Tiktoken processing successful ({token_count} tokens)."
+         logging.info(f"Tiktoken processing complete. Found {token_count} tokens.")
+         return highlighted_data, token_count, status.strip()
+
+     except Exception as e:
+         logging.error(f"Error during tiktoken processing: {e}")
+         logging.error(traceback.format_exc())
+         status += f"Error during Tiktoken processing: {e}"
+         return [(f"Error processing with tiktoken: {e}", "Error")], 0, status.strip()

+ def create_llama3_highlight_data(prompt: str, tokenizer: AutoTokenizer) -> Tuple[List[Tuple[str, str]], int, str]:
      """Generates data for gr.HighlightedText based on Llama 3 tokenizer."""
+     status = f"Processing with Llama 3 ({tokenizer.name_or_path})...\n"
      try:
          llama_token_ids = tokenizer.encode(prompt)
          highlighted_data = []
          for i, token_id in enumerate(llama_token_ids):
              try:
+                 # Decode individual token. Add special handling if needed for specific tokenizers.
                  token_text = tokenizer.decode([token_id])
              except Exception as e:
+                 logging.warning(f"Unexpected Llama 3 decode error for token ID {token_id}: {e}")
                  token_text = "[Decode Error]"

+             token_label = f"Llama3 Tk {i+1}"
              highlighted_data.append((token_text, token_label))

          token_count = len(llama_token_ids)
+         status += f"Llama 3 processing successful ({token_count} tokens)."
+         logging.info(f"Llama 3 processing complete. Found {token_count} tokens.")
+         return highlighted_data, token_count, status.strip()
+
+     except Exception as e:
+         logging.error(f"Error during Llama 3 processing: {e}")
+         logging.error(traceback.format_exc())
+         status += f"Error during Llama 3 processing: {e}"
+         return [(f"Error processing with Llama 3: {e}", "Error")], 0, status.strip()
+
+ # --- Global Initializations ---
+
+ # Initialize Bytelatent Processor (loads model if available)
+ blt_processor = BytelatentProcessor(Config.BLT_MODEL_NAME, Config.BLT_WEIGHTS_DIR)
+
+ # Load Tiktoken Encoding
+ try:
+     tiktoken_encoding = tiktoken.get_encoding(Config.TIKTOKEN_ENCODING_NAME)
+     logging.info(f"Tiktoken encoding '{Config.TIKTOKEN_ENCODING_NAME}' loaded.")
+     tiktoken_available = True
+ except Exception as e:
+     logging.error(f"Failed to load Tiktoken encoding '{Config.TIKTOKEN_ENCODING_NAME}': {e}")
+     tiktoken_encoding = None
+     tiktoken_available = False
+
+ # Load Llama 3 Tokenizer
+ try:
+     # Use trust_remote_code=True if required by the specific model revision
+     llama_tokenizer = AutoTokenizer.from_pretrained(Config.LLAMA3_MODEL_NAME) #, trust_remote_code=True)
+     logging.info(f"Llama 3 tokenizer '{Config.LLAMA3_MODEL_NAME}' loaded.")
+     llama_available = True
+ except ImportError:
+     logging.error("Transformers or SentencePiece library not found. Llama 3 functionality disabled. Install with: pip install transformers sentencepiece")
+     llama_tokenizer = None
+     llama_available = False
+ except OSError as e:
+     logging.error(f"Error loading Llama 3 tokenizer '{Config.LLAMA3_MODEL_NAME}': {e}")
+     error_msg = f"Could not load Llama 3 tokenizer '{Config.LLAMA3_MODEL_NAME}'. Check model name, network, and authentication (use `huggingface-cli login` if needed)."
+     logging.error(error_msg)
+     llama_tokenizer = None
+     llama_available = False
+ except Exception as e:
+     logging.error(f"An unexpected error occurred loading Llama 3 tokenizer: {e}")
+     logging.error(traceback.format_exc())
+     llama_tokenizer = None
+     llama_available = False

  # --- Main Processing Function ---

+ def process_text(prompt: str) -> Tuple[
+     Optional[matplotlib.figure.Figure], List[Tuple[str, str]], int, # BLT
+     List[Tuple[str, str]], int, # Tiktoken
+     List[Tuple[str, str]], int, # Llama 3
+     str # Status
+ ]:
      """
      Processes the input prompt using ByteLatent, Tiktoken, and Llama 3,
      returning visualizations, counts, and status.
      """
+     status_messages = ["Processing started..."]
      fig = None
+     bl_highlighted_data, bl_count = [("Bytelatent not available.", "Error")], 0
+     tk_highlighted_data, tk_count = [("Tiktoken not available.", "Error")], 0
+     llama_highlighted_data, llama_count = [("Llama 3 not available.", "Error")], 0
+
+     # 1. Bytelatent Processing
+     if blt_processor.is_available:
+         fig, bl_highlighted_data, bl_count, bl_status = blt_processor.process(prompt, Config.BLT_MAX_BYTES_FOR_DEMO)
+         status_messages.append(f"Bytelatent Status:\n{bl_status}")
      else:
+         status_messages.append("Bytelatent Status: Skipped (processor unavailable).")

+     # 2. Tiktoken Processing
+     if tiktoken_available and tiktoken_encoding:
+         tk_highlighted_data, tk_count, tk_status = create_tiktoken_highlight_data(prompt, tiktoken_encoding)
+         status_messages.append(f"Tiktoken Status:\n{tk_status}")
      else:
+         status_messages.append("Tiktoken Status: Skipped (tokenizer unavailable).")

+     # 3. Llama 3 Processing
+     if llama_available and llama_tokenizer:
+         llama_highlighted_data, llama_count, llama_status = create_llama3_highlight_data(prompt, llama_tokenizer)
+         status_messages.append(f"Llama 3 Status:\n{llama_status}")
+     else:
+         status_messages.append("Llama 3 Status: Skipped (tokenizer unavailable).")

+     final_status = "\n---\n".join(status_messages)
+     if fig is not None and matplotlib is not None:
+         try:
+             plt.close(fig) # Close the specific figure
+             logging.debug("Closed Matplotlib figure.")
+         except Exception as close_err:
+             logging.warning(f"Could not close Matplotlib figure: {close_err}")
+     return fig, bl_highlighted_data, bl_count, tk_highlighted_data, tk_count, llama_highlighted_data, llama_count, final_status

  # --- Gradio Interface ---

+ def create_color_map(label_prefix: str, colors: List[str], max_segments: int) -> Dict[str, str]:
+     """Generates a color map dictionary for Gradio HighlightedText."""
+     color_cycler = itertools.cycle(colors)
+     color_map = {f"{label_prefix} {i+1}": next(color_cycler) for i in range(max_segments)}
+     color_map.update({"Error": "#FF0000", "Info": "#808080", "BL Remainder": "#AAAAAA"}) # Common labels
+     return color_map

+ bytelatent_color_map = create_color_map("BL Patch", Config.VIZ_COLORS, Config.MAX_EXPECTED_SEGMENTS)
+ tiktoken_color_map = create_color_map("GPT4 Tk", Config.VIZ_COLORS, Config.MAX_EXPECTED_SEGMENTS)
+ llama3_color_map = create_color_map("Llama3 Tk", Config.VIZ_COLORS, Config.MAX_EXPECTED_SEGMENTS)

+ with gr.Blocks(theme=Config.GRADIO_THEME) as iface:
+     gr.Markdown(f"# {Config.GRADIO_TITLE}")
+     gr.Markdown(Config.GRADIO_DESC)

      with gr.Row():
          with gr.Column(scale=1): # Input Column
              prompt_input = gr.Textbox(
                  label="Input Prompt",
+                 value=Config.DEFAULT_PROMPT,
                  placeholder="Enter text here...",
+                 # Max length is for UI input; Bytelatent truncation happens in backend
                  lines=5,
+                 info=f"Note: Bytelatent processing is limited to ~{Config.BLT_MAX_BYTES_FOR_DEMO} bytes for this demo."
              )
              submit_button = gr.Button("Generate Visualizations", variant="primary")
+             status_output = gr.Textbox(label="Processing Status", interactive=False, lines=10) # More space for detailed status

          with gr.Column(scale=2): # Output Column
              # --- Bytelatent Output Area ---
+             if blt_processor.is_available: # Only show BLT section if it loaded
+                 with gr.Accordion("BLT Entropy Patcher Output (`blt_main_entropy_100m_512w`)", open=True):
+                     with gr.Row():
+                         bl_count_output = gr.Number(label="Patch Count", value=0, interactive=False, step=1, scale=0)
+                         highlighted_output_bl = gr.HighlightedText(
+                             label="BLT Patches",
+                             color_map=bytelatent_color_map,
+                             show_legend=False,
+                             show_inline_category=False,
+                             container=False
+                         )
+                     plot_output = gr.Plot(label="Entropy vs. Token Index")
+             else:
+                 gr.Markdown(f"### Bytelatent Output (`{Config.BLT_MODEL_NAME}`)")
+                 gr.Markdown("_(Bytelatent processor failed to load or libraries are missing. Output unavailable.)_")
+                 # Define dummy outputs if BLT is unavailable so the `outputs` list doesn't break
+                 highlighted_output_bl = gr.HighlightedText(value=[("BLT Unavailable", "Error")], label="BLT Patches", visible=False)
+                 bl_count_output = gr.Number(value=0, label="Patch Count", visible=False)
+                 plot_output = gr.Plot(label="Entropy Plot", visible=False)
+

              # --- Tiktoken Output Area ---
+             if tiktoken_available: # Only show Tiktoken section if it loaded
+                 with gr.Accordion(f"Tiktoken Output (`{Config.TIKTOKEN_ENCODING_NAME}`)", open=True):
+                     with gr.Row():
+                         tk_count_output = gr.Number(label="Token Count", value=0, interactive=False, step=1, scale=0)
+                         highlighted_output_tk = gr.HighlightedText(
+                             label="Tiktoken Segments",
+                             color_map=tiktoken_color_map,
+                             show_legend=False,
+                             show_inline_category=False,
+                             container=False
+                         )
+             else:
+                 gr.Markdown(f"### Tiktoken Output (`{Config.TIKTOKEN_ENCODING_NAME}`)")
+                 gr.Markdown("_(Tiktoken failed to load. Output unavailable.)_")
+                 highlighted_output_tk = gr.HighlightedText(value=[("Tiktoken Unavailable", "Error")], label="Tiktoken Segments", visible=False)
+                 tk_count_output = gr.Number(value=0, label="Token Count", visible=False)

              # --- Llama 3 Output Area ---
+             if llama_available: # Only show Llama section if it loaded
+                 with gr.Accordion(f"Llama 3 Output (`{Config.LLAMA3_MODEL_NAME}`)", open=True):
+                     with gr.Row():
+                         llama_count_output = gr.Number(label="Token Count", value=0, interactive=False, step=1, scale=0)
+                         highlighted_output_llama = gr.HighlightedText(
+                             label="Llama 3 Segments",
+                             color_map=llama3_color_map,
+                             show_legend=False,
+                             show_inline_category=False,
+                             container=False
+                         )
+             else:
+                 gr.Markdown(f"### Llama 3 Output (`{Config.LLAMA3_MODEL_NAME}`)")
+                 gr.Markdown("_(Llama 3 tokenizer failed to load. Output unavailable.)_")
+                 highlighted_output_llama = gr.HighlightedText(value=[("Llama 3 Unavailable", "Error")], label="Llama 3 Segments", visible=False)
+                 llama_count_output = gr.Number(value=0, label="Token Count", visible=False)
+

      # Define the action for the button click
      submit_button.click(
          fn=process_text,
          inputs=prompt_input,
+         # Ensure order matches the return values of process_text
          outputs=[
+             # Bytelatent outputs (even if dummy/hidden)
+             plot_output,
+             highlighted_output_bl,
+             bl_count_output,
+             # Tiktoken outputs (even if dummy/hidden)
+             highlighted_output_tk,
+             tk_count_output,
+             # Llama 3 outputs (even if dummy/hidden)
+             highlighted_output_llama,
+             llama_count_output,
+             # Status output
+             status_output
+         ]
      )

  # --- Launch the Gradio App ---
  if __name__ == "__main__":
+     logging.info("-----------------------------------------")
+     logging.info("Starting Gradio App...")
+     logging.info(f"Bytelatent Available: {blt_processor.is_available}")
+     logging.info(f"Tiktoken Available: {tiktoken_available}")
+     logging.info(f"Llama 3 Tokenizer Available: {llama_available}")
+     logging.info("-----------------------------------------")
      iface.launch()
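
For reviewers who want to sanity-check the tokenizer half of this change without launching the Space, here is a minimal standalone sketch (not part of app.py). It assumes `tiktoken` and `transformers` are installed and that the gated `meta-llama/Meta-Llama-3-8B` repo is accessible; it uses the same APIs as the refactored helpers above and the demo's default prompt to reproduce the per-token segmentation and counts the app visualizes.

# Minimal standalone sketch: reproduce the tiktoken and Llama 3 segmentations shown by the Space.
# Assumes: pip install tiktoken transformers, plus access to the gated Llama 3 repo (huggingface-cli login).
import tiktoken
from transformers import AutoTokenizer

prompt = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."

# GPT-4-style segmentation via tiktoken's cl100k_base encoding
enc = tiktoken.get_encoding("cl100k_base")
gpt4_ids = enc.encode(prompt)
print(f"GPT-4 (cl100k_base): {len(gpt4_ids)} tokens")
print([enc.decode([tid]) for tid in gpt4_ids])

# Llama 3 segmentation via the Hugging Face tokenizer
llama_tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
llama_ids = llama_tok.encode(prompt)
print(f"Llama 3: {len(llama_ids)} tokens")
print([llama_tok.decode([tid]) for tid in llama_ids])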