# my-im2svg / app.py
# Author: Rinawang — commit 593d8ab ("Create app.py", verified); originally 925 bytes.
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, VisionEncoderDecoderModel
# Load the model and its processor.
# StarVector ships its own architecture as Hub remote code, so it is loaded via
# AutoModelForCausalLM with trust_remote_code=True; it is not a
# VisionEncoderDecoderModel checkpoint.
model_id = "starvector/starvector-8b-im2svg"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision roughly halves GPU memory for the 8B weights; keep float32 on
# CPU, where half-precision support is limited.
dtype = torch.float16 if device == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=dtype,
    trust_remote_code=True,
)
model.to(device)
model.eval()  # inference only
# The image processor is bundled with the remote model code rather than being
# a standalone AutoProcessor.
processor = model.model.processor

# Upper bound on the generated SVG length (in tokens). Too small a value
# truncates the SVG and yields invalid markup.
MAX_SVG_LENGTH = 4000


# Inference function used by the Gradio interface.
def im2svg(image, max_length=MAX_SVG_LENGTH):
    """Convert a raster image into SVG source code with StarVector.

    Args:
        image: PIL image from the Gradio upload widget, or None when the form
            is submitted without an image.
        max_length: length limit forwarded to ``model.generate_im2svg``.

    Returns:
        The generated SVG markup as a string, or an empty string when no image
        was provided.
    """
    if image is None:
        # Nothing uploaded: return an empty result instead of crashing
        # inside the processor.
        return ""
    pixel_values = processor(image, return_tensors="pt")["pixel_values"].to(device)
    batch = {"image": pixel_values}
    with torch.inference_mode():
        # generate_im2svg returns one SVG string per image in the batch.
        generated_svg = model.generate_im2svg(batch, max_length=max_length)[0]
    return generated_svg
# Build the Gradio interface: one image upload in, SVG source text out.
demo = gr.Interface(
    fn=im2svg,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="StarVector 8B - Image to SVG",
    description="上传图像,将其转化为 SVG 矢量代码。",
    allow_flagging="never",
)

# Only start the server when run as a script (e.g. `python app.py`), so that
# importing this module does not block on launch().
if __name__ == "__main__":
    demo.launch()