File size: 5,446 Bytes
570eaa9
 
 
 
 
 
 
2af55e5
570eaa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2af55e5
 
 
 
 
 
 
 
 
570eaa9
 
 
 
2af55e5
570eaa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2af55e5
570eaa9
 
 
 
 
2af55e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570eaa9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import gradio as gr
import torch

from bytelatent.data.file_util import get_fs
from bytelatent.generate_patcher import patcher_nocache
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
from bytelatent.args import TrainArgs
from download_blt_weights import main as ensure_present

# --- Global Setup (Consider loading models outside if necessary) ---
# Kept inside the function for simplicity as before.

def process_text(prompt: str, model_name: str = "blt-1b"):
    """
    Run the ByteLatent entropy patcher on *prompt* and build an entropy plot.

    Args:
        prompt: The input text string from the Gradio interface.
        model_name: The name of the model to use (a directory under
            "hf-weights" containing params.json and entropy_model/).

    Returns:
        A matplotlib figure from plot_entropies, or None when processing
        produced no results (gr.Plot renders nothing for None).

    Raises:
        gr.Error: On missing model files or any unexpected failure, so the
            message is shown in the Gradio UI. Returning a plain string here
            (as the old code did) is invalid for a gr.Plot output and made
            the component fail to render instead of showing the error.
    """
    try:
        # --- Model and Tokenizer Loading ---
        # NOTE(review): model, tokenizer and patcher are re-built on every
        # request; consider caching them at module level to cut latency.
        consolidated_path = os.path.join("hf-weights", model_name)
        train_args_path = os.path.join(consolidated_path, "params.json")

        if not os.path.exists(train_args_path):
            raise FileNotFoundError(f"Training args not found at {train_args_path}. "
                                    f"Ensure model '{model_name}' is downloaded/available.")

        fs = get_fs(train_args_path)
        train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

        tokenizer = train_args.data.tokenizer_args.build()
        assert isinstance(tokenizer, BltTokenizer)

        # Deep-copy the patcher config so the realtime/device tweaks below
        # don't mutate the shared train_args object.
        patcher_args = train_args.data.patcher_args.model_copy(deep=True)
        patcher_args.realtime_patching = True
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        patcher_args.patching_device = device
        patcher_args.device = device

        print("Loading entropy model and patcher...")
        entropy_model_dir = os.path.join(consolidated_path, "entropy_model")
        if not os.path.exists(entropy_model_dir):
            raise FileNotFoundError(f"Entropy model directory not found at {entropy_model_dir}.")

        patcher_args.entropy_model_checkpoint_dir = entropy_model_dir
        patcher = patcher_args.build()
        # --- End Loading ---

        # --- Processing ---
        print(f"Processing prompt: '{prompt}'")
        results = patcher_nocache(
            [prompt], tokenizer=tokenizer, patcher=patcher
        )

        if not results:
            print("Processing returned no results.")
            # None is a valid "empty" value for gr.Plot, unlike a string.
            return None

        batch_patch_lengths, batch_scores, batch_tokens = results
        # Only one prompt was submitted, so only the first batch row is
        # plotted — decode just that row instead of the whole batch.
        decoded_output = tokenizer.decode(batch_tokens[0].tolist())
        fig = plot_entropies(
            batch_patch_lengths[0],
            batch_scores[0],
            decoded_output,
            threshold=patcher.threshold,
        )

        print("Processing and decoding complete.")
        # --- End Processing ---

        return fig

    except FileNotFoundError as e:
        print(f"Error: {e}")
        # Surface the specific error in the Gradio UI.
        raise gr.Error(str(e)) from e
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"An error occurred during processing: {e}") from e


# NOTE(review): a gr.Interface was previously built here and bound to
# `iface`, but it was immediately shadowed by the `with gr.Blocks() as
# iface:` layout below — only the Blocks app is ever served. It also used
# the deprecated `allow_flagging` keyword. Removed as dead code so a
# second, unused UI is no longer constructed at import time.

# Main UI: a Blocks layout with a prompt textbox, a submit button, and the
# entropy plot produced by process_text. This rebinds `iface`, shadowing the
# gr.Interface defined above — only this Blocks app is launched.
with gr.Blocks() as iface:
    gr.Markdown("# ByteLatent Entropy Visualizer") # Title
    gr.Markdown(
        "Process any prompt (limited to 512 bytes) with the 100M entropy patcher model "
        "and visualize the token entropies plot below.<br><br>" # Updated description
        "NOTE: this implementation differs slightly by excluding local attention so we limit "
        "the characters limit to 512 to avoid any deviation.",
        line_breaks=True
    )

    with gr.Column():
        # max_length=512 enforces the byte/character limit mentioned in the
        # description at the UI level.
        prompt_input = gr.Textbox(
            label="Input Prompt",
            value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
            placeholder="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
            max_length=512
        )
        submit_button = gr.Button("Generate Plot") # Add button
        plot_output = gr.Plot(label="Entropy w Threshold") # Output component

    # Define the action when the button is clicked
    submit_button.click(
        fn=process_text,
        inputs=prompt_input,      # Input component(s)
        outputs=plot_output       # Output component(s)
    )

# --- Launch the Gradio App ---
if __name__ == "__main__":
    # Download the "blt-1b" weights into ./hf-weights if missing (see
    # download_blt_weights.main), then start the Gradio server.
    ensure_present(["blt-1b"])
    iface.launch()