import os
import gradio as gr
import torch
from bytelatent.data.file_util import get_fs
from bytelatent.generate_patcher import patcher_nocache
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
from bytelatent.args import TrainArgs
from download_blt_weights import main as ensure_present
# --- Global setup ---
# Model and tokenizer loading is kept inside the request handler for simplicity;
# for lower latency, consider loading once at module scope (see the cached sketch below).
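# Hedged caching sketch (an illustration only, not wired into the demo): memoize
# the tokenizer build per model name so repeated requests skip re-reading
# params.json from disk. _load_tokenizer_cached is a hypothetical helper; the
# request handler below keeps its original inline loading.
from functools import lru_cache


@lru_cache(maxsize=2)
def _load_tokenizer_cached(model_name: str) -> BltTokenizer:
    params_path = os.path.join("hf-weights", model_name, "params.json")
    args = TrainArgs.model_validate_json(get_fs(params_path).read_text(params_path))
    return args.data.tokenizer_args.build()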
def process_text(prompt: str, model_name: str = "blt-1b"):
    """
    Processes the input prompt with the ByteLatent entropy patcher and plots the result.

    Args:
        prompt: The input text string from the Gradio interface.
        model_name: The name of the model to use.

    Returns:
        A matplotlib figure of per-byte entropies with patch boundaries, or None
        if processing produced no results. Errors are surfaced via gr.Error.
    """
    try:
        # --- Model and tokenizer loading ---
        consolidated_path = os.path.join("hf-weights", model_name)
        train_args_path = os.path.join(consolidated_path, "params.json")

        if not os.path.exists(train_args_path):
            raise FileNotFoundError(
                f"Training args not found at {train_args_path}. "
                f"Ensure model '{model_name}' is downloaded/available."
            )

        fs = get_fs(train_args_path)
        train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

        tokenizer = train_args.data.tokenizer_args.build()
        assert isinstance(tokenizer, BltTokenizer)
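        # BltTokenizer is byte-level: each token corresponds to one UTF-8 byte of
        # the prompt, which is what the entropy plot below indexes.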
        patcher_args = train_args.data.patcher_args.model_copy(deep=True)
        patcher_args.realtime_patching = True

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        patcher_args.patching_device = device
        patcher_args.device = device

        print("Loading entropy model and patcher...")
        entropy_model_dir = os.path.join(consolidated_path, "entropy_model")
        if not os.path.exists(entropy_model_dir):
            raise FileNotFoundError(f"Entropy model directory not found at {entropy_model_dir}.")
        patcher_args.entropy_model_checkpoint_dir = entropy_model_dir
        patcher = patcher_args.build()
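        # The patcher segments the byte stream into patches wherever the small
        # entropy model's next-byte entropy crosses patcher.threshold, the same
        # threshold passed to the plot below.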
        # --- End loading ---

        # --- Processing ---
        prompts = [prompt]
        print(f"Processing prompt: '{prompt}'")
        results = patcher_nocache(prompts, tokenizer=tokenizer, patcher=patcher)

        if not results:
            print("Processing returned no results.")
            return None  # Nothing to plot; the Plot output stays empty.

        batch_patch_lengths, batch_scores, batch_tokens = results
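        # results unpacks into per-prompt sequences (shapes inferred from usage
        # below): patch lengths, per-byte entropy scores, and raw byte token ids.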
        # Decode every row in the batch (here the batch holds a single prompt).
        decoded_chars_list = [tokenizer.decode(row_tokens.tolist()) for row_tokens in batch_tokens]

        fig = None
        if decoded_chars_list:
            decoded_output = decoded_chars_list[0]
            fig = plot_entropies(
                batch_patch_lengths[0],
                batch_scores[0],
                decoded_output,
                threshold=patcher.threshold,
            )

        print("Processing and decoding complete.")
        # --- End processing ---
        return fig
    except FileNotFoundError as e:
        print(f"Error: {e}")
        # Surface the error in the Gradio UI; returning a string would break gr.Plot.
        raise gr.Error(str(e))
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"An error occurred during processing: {e}")
# --- Build the Gradio UI ---
# (A plain gr.Interface would work too, but it was immediately shadowed by the
# Blocks layout below, which gives a titled page, a character-capped textbox,
# and an explicit submit button.)
with gr.Blocks() as iface:
    gr.Markdown("# ByteLatent Entropy Visualizer")
    gr.Markdown(
        "Process any prompt (limited to 512 bytes) with the 100M entropy patcher model "
        "and visualize the token entropies plot below.<br><br>"
        "NOTE: this implementation differs slightly by excluding local attention, "
        "so we cap the input at 512 characters to avoid any deviation.",
        line_breaks=True
    )
    with gr.Column():
        prompt_input = gr.Textbox(
            label="Input Prompt",
            value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
            placeholder="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
            max_length=512,
        )
        submit_button = gr.Button("Generate Plot")
        plot_output = gr.Plot(label="Entropy with Threshold")

    # Wire the button to the processing function.
    submit_button.click(
        fn=process_text,
        inputs=prompt_input,
        outputs=plot_output,
    )
# --- Launch the Gradio app ---
if __name__ == "__main__":
    ensure_present(["blt-1b"])  # Download weights on first run.
    iface.launch()
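# launch() accepts standard Gradio options (e.g. server_name="0.0.0.0" or
# share=True for a public link); the defaults are sufficient on Spaces.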