File size: 4,956 Bytes
949310d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os, tempfile, time
import gradio as gr
from tool.test import run_autotune_pipeline, DATA_DIR
# ---------- Core callback ----------
def get_test_text(test_file, test_data_input):
    """
    Extract test-data text from an uploaded file, falling back to the textbox.

    Handles the several shapes Gradio may hand us for an upload: a plain
    filepath string (newer Gradio versions), a file-like object (has .read),
    a wrapper with .data, or an object exposing a .name path attribute.

    Args:
        test_file      : Gradio upload value (path str, file-like, wrapper, or None).
        test_data_input: raw text from the test-data textbox (may be None).

    Returns:
        The test data as a str ("" when nothing was provided).
    """
    if test_file is not None:
        # Newer Gradio versions pass the upload as a plain filepath string
        # (NamedString is a str subclass, so it is caught here too).
        if isinstance(test_file, str) and os.path.exists(test_file):
            with open(test_file, "r", encoding="utf-8") as f:
                return f.read()
        if hasattr(test_file, "read"):
            # A text-mode file object returns str from .read(); only decode bytes.
            raw = test_file.read()
            return raw if isinstance(raw, str) else raw.decode("utf-8")
        if hasattr(test_file, "data"):
            return test_file.data if isinstance(test_file.data, str) else test_file.data.decode("utf-8")
        if hasattr(test_file, "name") and os.path.exists(test_file.name):
            with open(test_file.name, "r", encoding="utf-8") as f:
                return f.read()
    # Fall back to the textbox contents.
    return test_data_input or ""
def generate_kernel(text_input, test_data_input, test_file, n_iters, progress=gr.Progress()):
    """
    Run the autotune pipeline and stream progress back to the UI.

    Args:
        text_input     : string from textbox (NL description or base CUDA code).
        test_data_input: test data pasted into the textbox (variable name, data).
        test_file      : gr.File upload value (or None); takes precedence over
                         the textbox when present.
        n_iters        : maximum number of tuning iterations.
        progress       : Gradio progress tracker (injected by Gradio).

    Returns:
        The best kernel code string found, or a warning message on bad input.
    """
    progress((0, n_iters), desc="Initializing...")

    # Guard: the textbox value may be None (not just empty) on a blank submit.
    if not (text_input or "").strip():
        return "⚠️ Please paste a description or baseline CUDA code."

    # Select test-data source: uploaded file wins, textbox is the fallback.
    test_text = get_test_text(test_file, test_data_input)
    if not test_text.strip():
        return "⚠️ Test data required."

    best_code = ""
    for info in run_autotune_pipeline(
        input_code=text_input,
        test_data_input=test_text,
        test_file=None,
        bin_dir=DATA_DIR,
        max_iterations=int(n_iters),
    ):
        # Update the progress bar whenever the pipeline reports an iteration.
        if info["iteration"] is not None:
            progress((info["iteration"], n_iters), desc=info["message"])
        # Keep the most recent kernel code the pipeline emitted.
        if info["code"]:
            best_code = info["code"]

    # TBD: wire the result into a download button.
    return best_code
# ---------- Gradio UI ----------
with gr.Blocks(
    title="KernelPilot",
    theme=gr.themes.Soft(
        text_size="lg",
        font=[
            "system-ui",
            "-apple-system",
            "BlinkMacSystemFont",
            "Segoe UI",
            "Roboto",
            "Helvetica Neue",
            "Arial",
            "Noto Sans",
            "sans-serif",
        ],
    ),
) as demo:
    gr.Markdown(
        """# 🚀 KernelPilot Optimizer
Enter a code, test data, then click **Generate** to obtain the optimized kernel function."""
    )
    with gr.Row():
        # Kernel source (or natural-language description) to optimize.
        txt_input = gr.Textbox(
            label="📝 Input",
            lines=10,
            placeholder="Enter the code",
            scale=3,
        )
        # Passed to generate_kernel as n_iters (max tuning iterations).
        level = gr.Number(
            label="Optimization Level",  # fixed label typo ("Optimazation")
            minimum=1,
            maximum=5,
            value=5,
            step=1,
            scale=1,
        )
    with gr.Row():
        # Inline test-data entry; format described by the placeholder.
        test_data_input = gr.Textbox(
            label="Test Data Input",
            lines=10,
            placeholder="<number_of_test_cases>\n<number_of_variables>\n\n<variable_1_name>\n<variable_1_testcase_1_data>\n<variable_1_testcase_2_data>\n...\n<variable_1_testcase_N_data>\n\n<variable_2_name>\n<variable_2_testcase_1_data>\n...\n<variable_2_testcase_N_data>\n\n...",
            scale=2,
        )
        # Alternative: upload the same test data as a .txt file
        # (takes precedence over the textbox — see get_test_text).
        test_file = gr.File(
            label="Upload Test Data (.txt)",
            file_types=["text"],
            scale=1,
        )
    gen_btn = gr.Button("⚡ Generate")
    kernel_output = gr.Code(
        label="🎯 Tuned CUDA Kernel",
        language="cpp",
    )
    gen_btn.click(
        fn=generate_kernel,
        inputs=[txt_input, test_data_input, test_file, level],
        outputs=[kernel_output],
        queue=True,                       # keeps requests queued
        show_progress=True,               # show progress bar
        show_progress_on=kernel_output,   # render progress on the output box
    )

if __name__ == "__main__":
    # Single worker serializes tuning jobs (one GPU run at a time);
    # up to 50 requests may wait in the queue.
    demo.queue(default_concurrency_limit=1, max_size=50)
    demo.launch(server_name="0.0.0.0", server_port=7860)