my-im2svg / app.py
Rinawang's picture
Update app.py
2c40e39 verified
import gradio as gr
from transformers import (
SiglipImageProcessor,
RobertaTokenizerFast,
VisionEncoderDecoderModel
)
from PIL import Image
import torch
# Hugging Face Hub ID of the checkpoint; every component below is loaded from it.
model_id = "starvector/starvector-8b-im2svg"

# Prefer the GPU when torch can see one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The image processor and the tokenizer are loaded separately from the same repo.
image_processor = SiglipImageProcessor.from_pretrained(model_id)
tokenizer = RobertaTokenizerFast.from_pretrained(model_id)

# NOTE(review): confirm that this checkpoint is really loadable as a
# VisionEncoderDecoderModel; its model card may prescribe a different loader.
# Module.to() returns the module itself, so chaining it onto the load is
# equivalent to a separate model.to(device) statement.
model = VisionEncoderDecoderModel.from_pretrained(model_id).to(device)
# Inference function
def im2svg(image) -> str:
    """Convert an uploaded image into SVG source text.

    Args:
        image: The picture handed over by the Gradio image input (configured
            with ``type="pil"`` in this file), or ``None`` when the form is
            submitted before anything was uploaded.

    Returns:
        The first decoded sequence generated by the model, with special
        tokens stripped, or an empty string when no image was supplied.
    """
    # Gradio passes None when the user submits without an image; return early
    # instead of letting the image processor fail on it with an opaque error.
    if image is None:
        return ""

    # Preprocess into tensors and move them to the same device as the model.
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    # Cap generation at 1024 new tokens.
    outputs = model.generate(**inputs, max_new_tokens=1024)
    # A single image was submitted, so only the first decoded entry is used.
    svg_code = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    return svg_code
# Gradio UI
# Input widget: delivers the uploaded picture to im2svg as a PIL image.
_image_input = gr.Image(type="pil")

# User-facing texts shown on the page (kept exactly as authored).
_title = "🖼️ StarVector: Image → SVG"
_description = "上传图像,我将它转化为矢量图(SVG 代码)。适用于简笔画、图标、草图。"

# Single-function interface: one image in, the generated SVG code out as text.
demo = gr.Interface(
    fn=im2svg,
    inputs=_image_input,
    outputs="text",
    title=_title,
    description=_description,
)

# Launch the interface as soon as the module is executed.
demo.launch()