Spaces: Running on Zero
File size: 2,282 Bytes
9b25f0e b4fa047 953563e b4fa047 c5d3fee b4fa047 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import gradio as gr
import spaces
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from huggingface_hub import snapshot_download
from pathlib import Path
# Download and prepare the Pixtral model files.
# They are stored under ~/mistral_models/Pixtral. Only the files matching
# allow_patterns are fetched: params, weights, and the tokenizer definition.
mistral_models_path = Path.home() / "mistral_models" / "Pixtral"
mistral_models_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=mistral_models_path,
)
# Load the tokenizer and model once, at import time. mistral_inference()
# reads these module-level globals on every request.
# Build the tokenizer path with pathlib rather than string formatting.
tokenizer = MistralTokenizer.from_file(str(mistral_models_path / "tekken.json"))
model = Transformer.from_folder(mistral_models_path)
# Inference
@spaces.GPU
def mistral_inference(prompt, image_url, max_tokens=1024, temperature=0.35):
    """Run Pixtral on one image plus a text prompt and return the decoded reply.

    ``spaces.GPU`` is the Hugging Face Spaces decorator for ZeroGPU hardware.
    Presumably it attaches a GPU only for the duration of this call, so keep
    non-GPU work (e.g. input validation) outside this function.

    Args:
        prompt: User instruction text. It is placed after the image in a
            single user message.
        image_url: URL of the image, wrapped in an ``ImageURLChunk``.
        max_tokens: Maximum number of tokens to generate. Defaults to 1024,
            the value previously hard-coded here.
        temperature: Sampling temperature. Defaults to 0.35, the value
            previously hard-coded here.

    Returns:
        The generated text, decoded from the single output sequence.
    """
    completion_request = ChatCompletionRequest(
        messages=[
            UserMessage(
                content=[ImageURLChunk(image_url=image_url), TextChunk(text=prompt)]
            )
        ]
    )
    encoded = tokenizer.encode_chat_completion(completion_request)
    # generate() is called with a batch of one: one token list and the
    # matching list of images for that sequence.
    out_tokens, _ = generate(
        [encoded.tokens],
        model,
        images=[encoded.images],
        max_tokens=max_tokens,
        temperature=temperature,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.decode(out_tokens[0])
# Gradio interface
def process_input(text, image_url):
    """Validate the form inputs, then forward them to ``mistral_inference``.

    Validation happens here, before the ``spaces.GPU``-decorated call. A
    missing field is therefore reported to the user as a readable message
    rather than as an exception raised deep inside the model library.

    Args:
        text: Prompt entered in the text box.
        image_url: Image URL entered in the URL box.

    Returns:
        The model's generated text.

    Raises:
        gr.Error: If the prompt or the image URL is empty or whitespace-only.
    """
    if not text or not text.strip():
        raise gr.Error("テキストプロンプトを入力してください。")
    if not image_url or not image_url.strip():
        raise gr.Error("画像URLを入力してください。")
    # Strip whitespace picked up when copy-pasting, which would otherwise
    # become part of the URL.
    return mistral_inference(text, image_url.strip())
# Build the UI: prompt and image-URL inputs side by side, an output box,
# and a button that runs process_input on the two inputs.
with gr.Blocks() as demo:
    gr.Markdown("## Pixtralモデルによる画像説明生成")
    with gr.Row():
        text_input = gr.Textbox(label="テキストプロンプト", placeholder="例: Describe the image.")
        image_input = gr.Textbox(label="画像URL", placeholder="例: https://example.com/image.png")
    result_output = gr.Textbox(label="モデルの出力結果")
    submit_button = gr.Button("推論を実行")
    submit_button.click(process_input, inputs=[text_input, image_input], outputs=result_output)
# The trailing " |" left by page scraping was removed here; it was a syntax error.
demo.launch()