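"""Gradio demo app for GraphGen.

Configure the synthesizer (and optional trainee) LLM endpoints, upload a
source file (.txt, .json, or .jsonl), and generate QA data by building and
traversing a knowledge graph.
"""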
import os
import sys
import json
import tempfile

import pandas as pd
import gradio as gr
from gradio_i18n import Translate, gettext as _

from webui.base import GraphGenParams
from webui.test_api import test_api_connection
from webui.cache_utils import setup_workspace, cleanup_workspace
from webui.count_tokens import count_tokens

# Make the repository root importable before pulling in the graphgen package.
root_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(root_dir)

# pylint: disable=wrong-import-position
from graphgen.graphgen import GraphGen
from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
from graphgen.models.llm.limitter import RPM, TPM
from graphgen.utils import set_logger
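
# Extra CSS to center the language selector row in the header.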
css = """
.center-row {
display: flex;
justify-content: center;
align-items: center;
}
"""


def init_graph_gen(config: dict, env: dict) -> GraphGen:
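    """Create a GraphGen instance wired to the UI's model and traversal settings."""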
# Set up working directory
log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
set_logger(log_file, if_stream=False)
    graph_gen = GraphGen(working_dir=working_dir)
# Set up LLM clients
    graph_gen.synthesizer_llm_client = OpenAIModel(
        model_name=env.get("SYNTHESIZER_MODEL", ""),
        base_url=env.get("SYNTHESIZER_BASE_URL", ""),
        api_key=env.get("SYNTHESIZER_API_KEY", ""),
        request_limit=True,
        rpm=RPM(env.get("RPM", 1000)),
        tpm=TPM(env.get("TPM", 50000)),
    )
    graph_gen.trainee_llm_client = OpenAIModel(
        model_name=env.get("TRAINEE_MODEL", ""),
        base_url=env.get("TRAINEE_BASE_URL", ""),
        api_key=env.get("TRAINEE_API_KEY", ""),
        request_limit=True,
        rpm=RPM(env.get("RPM", 1000)),
        tpm=TPM(env.get("TPM", 50000)),
    )
graph_gen.tokenizer_instance = Tokenizer(
config.get("tokenizer", "cl100k_base"))
strategy_config = config.get("traverse_strategy", {})
graph_gen.traverse_strategy = TraverseStrategy(
qa_form=config.get("qa_form"),
expand_method=strategy_config.get("expand_method"),
bidirectional=strategy_config.get("bidirectional"),
max_extra_edges=strategy_config.get("max_extra_edges"),
max_tokens=strategy_config.get("max_tokens"),
max_depth=strategy_config.get("max_depth"),
edge_sampling=strategy_config.get("edge_sampling"),
isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
loss_strategy=str(strategy_config.get("loss_strategy"))
)
return graph_gen


# pylint: disable=too-many-statements
def run_graphgen(params, progress=gr.Progress()):
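    """Run the GraphGen pipeline end-to-end.

    Loads the input file, inserts it into the graph, optionally quizzes and
    judges with the trainee model, traverses the graph to synthesize QA pairs,
    and returns the output file plus updated token statistics.
    """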
def sum_tokens(client):
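        """Total tokens consumed across all recorded calls of an LLM client."""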
return sum(u["total_tokens"] for u in client.token_usage)
config = {
"if_trainee_model": params.if_trainee_model,
"input_file": params.input_file,
"tokenizer": params.tokenizer,
"qa_form": params.qa_form,
"web_search": False,
"quiz_samples": params.quiz_samples,
"traverse_strategy": {
"bidirectional": params.bidirectional,
"expand_method": params.expand_method,
"max_extra_edges": params.max_extra_edges,
"max_tokens": params.max_tokens,
"max_depth": params.max_depth,
"edge_sampling": params.edge_sampling,
"isolated_node_strategy": params.isolated_node_strategy,
"loss_strategy": params.loss_strategy
},
"chunk_size": params.chunk_size,
}
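
    # Endpoint credentials and rate limits forwarded to the LLM clients.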
env = {
"SYNTHESIZER_BASE_URL": params.synthesizer_url,
"SYNTHESIZER_MODEL": params.synthesizer_model,
"TRAINEE_BASE_URL": params.trainee_url,
"TRAINEE_MODEL": params.trainee_model,
"SYNTHESIZER_API_KEY": params.api_key,
"TRAINEE_API_KEY": params.trainee_api_key,
"RPM": params.rpm,
"TPM": params.tpm,
}
# Test API connection
test_api_connection(env["SYNTHESIZER_BASE_URL"],
env["SYNTHESIZER_API_KEY"], env["SYNTHESIZER_MODEL"])
if config['if_trainee_model']:
test_api_connection(env["TRAINEE_BASE_URL"],
env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"])
# Initialize GraphGen
graph_gen = init_graph_gen(config, env)
    graph_gen.clear()  # start from an empty graph/KV store
    graph_gen.progress_bar = progress  # report progress through the Gradio UI
try:
# Load input data
file = config['input_file']
if isinstance(file, list):
file = file[0]
data = []
if file.endswith(".jsonl"):
data_type = "raw"
with open(file, "r", encoding='utf-8') as f:
data.extend(json.loads(line) for line in f)
elif file.endswith(".json"):
data_type = "chunked"
with open(file, "r", encoding='utf-8') as f:
data.extend(json.load(f))
elif file.endswith(".txt"):
# 读取文件后根据chunk_size转成raw格式的数据
data_type = "raw"
content = ""
with open(file, "r", encoding='utf-8') as f:
lines = f.readlines()
for line in lines:
content += line.strip() + " "
size = int(config.get("chunk_size", 512))
chunks = [
content[i:i + size] for i in range(0, len(content), size)
]
data.extend([{"content": chunk} for chunk in chunks])
else:
raise ValueError(f"Unsupported file type: {file}")
# Process the data
graph_gen.insert(data, data_type)
if config['if_trainee_model']:
# Generate quiz
graph_gen.quiz(max_samples=config['quiz_samples'])
# Judge statements
graph_gen.judge()
else:
graph_gen.traverse_strategy.edge_sampling = "random"
# Skip judge statements
graph_gen.judge(skip=True)
# Traverse graph
graph_gen.traverse()
# Save output
output_data = graph_gen.qa_storage.data
with tempfile.NamedTemporaryFile(
mode="w",
suffix=".jsonl",
delete=False,
encoding="utf-8") as tmpfile:
json.dump(output_data, tmpfile, ensure_ascii=False)
output_file = tmpfile.name
synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
trainee_tokens = sum_tokens(graph_gen.trainee_llm_client) if config['if_trainee_model'] else 0
total_tokens = synthesizer_tokens + trainee_tokens
        # Fill in the actual token usage in the stats table shown in the UI.
        data_frame = params.token_counter
        try:
            _update_data = [[
                data_frame.iloc[0, 0],
                data_frame.iloc[0, 1],
                str(total_tokens),
            ]]
            data_frame = pd.DataFrame(_update_data, columns=data_frame.columns)
        except Exception as e:  # pylint: disable=broad-except
            raise gr.Error(f"DataFrame operation error: {str(e)}") from e
        return output_file, gr.DataFrame(
            label="Token Stats",
            headers=["Source Text Token Count", "Estimated Token Usage", "Token Used"],
            datatype="str",
            interactive=False,
            value=data_frame,
            visible=True,
            wrap=True)
    except Exception as e:  # pylint: disable=broad-except
        raise gr.Error(f"Error occurred: {str(e)}") from e
finally:
# Clean up workspace
cleanup_workspace(graph_gen.working_dir)
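

# Build the Gradio demo UI.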
with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
# Header
gr.Image(value="https://github.com/open-sciencelab/GraphGen/blob/main/resources/images/logo.png?raw=true",
label="GraphGen Banner",
elem_id="banner",
interactive=False,
container=False,
show_download_button=False,
show_fullscreen_button=False)
lang_btn = gr.Radio(
choices=[
("English", "en"),
("简体中文", "zh"),
],
value="en",
# label=_("Language"),
render=False,
container=False,
elem_classes=["center-row"],
)
gr.HTML("""
<div style="display: flex; gap: 8px; margin-left: auto; align-items: center; justify-content: center;">
<a href="https://github.com/open-sciencelab/GraphGen/releases">
<img src="https://img.shields.io/badge/Version-v0.1.0-blue" alt="Version">
</a>
<a href="https://graphgen-docs.example.com">
<img src="https://img.shields.io/badge/Docs-Latest-brightgreen" alt="Documentation">
</a>
<a href="https://github.com/open-sciencelab/GraphGen/issues/10">
<img src="https://img.shields.io/github/stars/open-sciencelab/GraphGen?style=social" alt="GitHub Stars">
</a>
<a href="https://arxiv.org/abs/2505.20416">
<img src="https://img.shields.io/badge/arXiv-pdf-yellow" alt="arXiv">
</a>
</div>
""")
with Translate(
os.path.join(root_dir, 'webui', 'translation.json'),
lang_btn,
placeholder_langs=["en", "zh"],
        # Set True to persist the language choice in the browser (requires gradio >= 5.6.0).
        persistant=False,
):
lang_btn.render()
        gr.Markdown(
            value="# " + _("Title") + "\n\n"
            + "### [GraphGen](https://github.com/open-sciencelab/GraphGen) " + _("Intro")
        )
if_trainee_model = gr.Checkbox(label=_("Use Trainee Model"),
value=False,
interactive=True)
with gr.Accordion(label=_("Model Config"), open=False):
synthesizer_url = gr.Textbox(label="Synthesizer URL",
value="https://api.siliconflow.cn/v1",
info=_("Synthesizer URL Info"),
interactive=True)
synthesizer_model = gr.Textbox(label="Synthesizer Model",
value="Qwen/Qwen2.5-7B-Instruct",
info=_("Synthesizer Model Info"),
interactive=True)
trainee_url = gr.Textbox(label="Trainee URL",
value="https://api.siliconflow.cn/v1",
info=_("Trainee URL Info"),
interactive=True,
visible=if_trainee_model.value is True)
trainee_model = gr.Textbox(
label="Trainee Model",
value="Qwen/Qwen2.5-7B-Instruct",
info=_("Trainee Model Info"),
interactive=True,
visible=if_trainee_model.value is True)
trainee_api_key = gr.Textbox(
label=_("SiliconCloud Token for Trainee Model"),
type="password",
value="",
info="https://cloud.siliconflow.cn/account/ak",
visible=if_trainee_model.value is True)
with gr.Accordion(label=_("Generation Config"), open=False):
chunk_size = gr.Slider(label="Chunk Size",
minimum=256,
maximum=4096,
value=512,
step=256,
interactive=True)
tokenizer = gr.Textbox(label="Tokenizer",
value="cl100k_base",
interactive=True)
qa_form = gr.Radio(choices=["atomic", "multi_hop", "aggregated"],
label="QA Form",
value="aggregated",
interactive=True)
quiz_samples = gr.Number(label="Quiz Samples",
value=2,
minimum=1,
interactive=True,
visible=if_trainee_model.value is True)
bidirectional = gr.Checkbox(label="Bidirectional",
value=True,
interactive=True)
expand_method = gr.Radio(choices=["max_width", "max_tokens"],
label="Expand Method",
value="max_tokens",
interactive=True)
max_extra_edges = gr.Slider(
minimum=1,
maximum=10,
value=5,
label="Max Extra Edges",
step=1,
interactive=True,
visible=expand_method.value == "max_width")
max_tokens = gr.Slider(minimum=64,
maximum=1024,
value=256,
label="Max Tokens",
step=64,
interactive=True,
visible=(expand_method.value
!= "max_width"))
max_depth = gr.Slider(minimum=1,
maximum=5,
value=2,
label="Max Depth",
step=1,
interactive=True)
edge_sampling = gr.Radio(
choices=["max_loss", "min_loss", "random"],
label="Edge Sampling",
value="max_loss",
interactive=True,
visible=if_trainee_model.value is True)
isolated_node_strategy = gr.Radio(choices=["add", "ignore"],
label="Isolated Node Strategy",
value="ignore",
interactive=True)
loss_strategy = gr.Radio(choices=["only_edge", "both"],
label="Loss Strategy",
value="only_edge",
interactive=True)
with gr.Row(equal_height=True):
with gr.Column(scale=3):
api_key = gr.Textbox(
label=_("SiliconCloud Token"),
type="password",
value="",
info="https://cloud.siliconflow.cn/account/ak")
with gr.Column(scale=1):
test_connection_btn = gr.Button(_("Test Connection"))
with gr.Blocks():
with gr.Row(equal_height=True):
with gr.Column():
rpm = gr.Slider(
label="RPM",
minimum=10,
maximum=10000,
value=1000,
step=100,
interactive=True,
visible=True)
with gr.Column():
tpm = gr.Slider(
label="TPM",
minimum=5000,
maximum=5000000,
value=50000,
step=1000,
interactive=True,
visible=True)
with gr.Blocks():
with gr.Row(equal_height=True):
with gr.Column(scale=1):
upload_file = gr.File(
label=_("Upload File"),
file_count="single",
file_types=[".txt", ".json", ".jsonl"],
interactive=True,
)
examples_dir = os.path.join(root_dir, 'webui', 'examples')
gr.Examples(examples=[
[os.path.join(examples_dir, "txt_demo.txt")],
[os.path.join(examples_dir, "raw_demo.jsonl")],
[os.path.join(examples_dir, "chunked_demo.json")],
],
inputs=upload_file,
label=_("Example Files"),
examples_per_page=3)
with gr.Column(scale=1):
output = gr.File(
label="Output(See Github FAQ)",
file_count="single",
interactive=False,
)
with gr.Blocks():
token_counter = gr.DataFrame(label='Token Stats',
headers=["Source Text Token Count", "Estimated Token Usage", "Token Used"],
datatype="str",
interactive=False,
visible=False,
wrap=True)
submit_btn = gr.Button(_("Run GraphGen"))
        # Test API connection; also check the trainee endpoint when the
        # trainee model is enabled at click time.
        def test_connections(use_trainee, syn_url, key, syn_model, t_url, t_model):
            test_api_connection(syn_url, key, syn_model)
            if use_trainee:
                test_api_connection(t_url, key, t_model)

        test_connection_btn.click(
            test_connections,
            inputs=[
                if_trainee_model, synthesizer_url, api_key, synthesizer_model,
                trainee_url, trainee_model
            ],
            outputs=[])
        expand_method.change(
            lambda method: (gr.update(visible=method == "max_width"),
                            gr.update(visible=method != "max_width")),
            inputs=expand_method,
            outputs=[max_extra_edges, max_tokens])
if_trainee_model.change(
lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
inputs=if_trainee_model,
outputs=[trainee_url, trainee_model, quiz_samples, edge_sampling, trainee_api_key])
        upload_file.change(
            lambda _file: gr.update(visible=True),
inputs=[upload_file],
outputs=[token_counter],
).then(
count_tokens,
inputs=[upload_file, tokenizer, token_counter],
outputs=[token_counter],
)
        # Run GraphGen
        submit_btn.click(
            lambda _df: gr.update(visible=False),
inputs=[token_counter],
outputs=[token_counter],
)
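
        # Build GraphGenParams from the positional inputs (order must match
        # the `inputs` list below) and run the pipeline.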
submit_btn.click(
lambda *args: run_graphgen(GraphGenParams(
if_trainee_model=args[0],
input_file=args[1],
tokenizer=args[2],
qa_form=args[3],
bidirectional=args[4],
expand_method=args[5],
max_extra_edges=args[6],
max_tokens=args[7],
max_depth=args[8],
edge_sampling=args[9],
isolated_node_strategy=args[10],
loss_strategy=args[11],
synthesizer_url=args[12],
synthesizer_model=args[13],
trainee_model=args[14],
api_key=args[15],
chunk_size=args[16],
rpm=args[17],
tpm=args[18],
quiz_samples=args[19],
trainee_url=args[20],
trainee_api_key=args[21],
token_counter=args[22],
)),
inputs=[
if_trainee_model, upload_file, tokenizer, qa_form,
bidirectional, expand_method, max_extra_edges, max_tokens,
max_depth, edge_sampling, isolated_node_strategy,
loss_strategy, synthesizer_url, synthesizer_model, trainee_model,
api_key, chunk_size, rpm, tpm, quiz_samples, trainee_url, trainee_api_key, token_counter
],
outputs=[output, token_counter],
)
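

# Queue long-running jobs (at most two concurrent runs) and serve on all
# network interfaces.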
if __name__ == "__main__":
demo.queue(api_open=False, default_concurrency_limit=2)
demo.launch(server_name='0.0.0.0')