AIBunCho commited on
Commit
c08c192
·
1 Parent(s): 95a67ca

Add application file

Browse files
Files changed (5) hide show
  1. .gitignore +2 -0
  2. app.py +105 -0
  3. images/0.jpg +0 -0
  4. images/1.jpg +0 -0
  5. requirements.txt +5 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ checkpoint-merged
2
+ flagged
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ # CUDA_VISIBLE_DEVICES 環境変数を設定して特定のGPUを使用
4
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
5
+
6
+ import torch
7
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
8
+ from PIL import Image
9
+ import gradio as gr
10
+ from qwen_vl_utils import process_vision_info # 必要に応じてインポートを調整
11
+
12
+ def load_model():
13
+ """
14
+ マージ済みモデルとプロセッサのロード
15
+ """
16
+ print("マージ済みモデルをロード中...")
17
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
18
+ "AIBunCho/AI_bokete", torch_dtype="auto", device_map="auto",
19
+ )
20
+ processor = AutoProcessor.from_pretrained("AIBunCho/AI_bokete")
21
+ print("マージ済みモデルのロード完了.")
22
+ return model, processor
23
+
24
+ def perform_inference(model, processor, image, prompt):
25
+ """
26
+ 推論の実行
27
+ """
28
+ messages = [
29
+ {
30
+ "role": "user",
31
+ "content": [
32
+ {
33
+ "type": "image",
34
+ "image": image, # プレースホルダー
35
+ },
36
+ {"type": "text", "text": prompt},
37
+ ],
38
+ }
39
+ ]
40
+
41
+ # 画像の前処理
42
+ image = image.convert("RGB")
43
+ image_inputs, video_inputs = process_vision_info(messages)
44
+
45
+ # テキストの準備
46
+ text = processor.apply_chat_template(
47
+ messages, tokenize=False, add_generation_prompt=True
48
+ )
49
+
50
+ # モデル入力の準備
51
+ inputs = processor(
52
+ text=[text],
53
+ images=image_inputs,
54
+ videos=video_inputs,
55
+ padding=True,
56
+ return_tensors="pt",
57
+ )
58
+
59
+ # デバイスへの転送 (cuda:0に統一)
60
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
61
+ model.to(device)
62
+ inputs = {k: v.to(device) for k, v in inputs.items()}
63
+
64
+ # モデルのすべてのパラメータを指定デバイスに移動
65
+ for param in model.parameters():
66
+ param.data = param.data.to(device)
67
+
68
+ # 推論
69
+ with torch.no_grad():
70
+ generated_ids = model.generate(**inputs, max_new_tokens=128)
71
+
72
+ # 生成されたIDをトリム
73
+ generated_ids_trimmed = [
74
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
75
+ ]
76
+
77
+ # 結果のデコード
78
+ output_text = processor.batch_decode(
79
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
80
+ )
81
+
82
+ return output_text[0]
83
+
84
+ def main():
85
+ # モデルとプロセッサのロード
86
+ model, processor = load_model()
87
+
88
+ # プロンプトの設定
89
+ prompt = "<image>画像を見てシュールで面白いことを言ってください。空欄がある場合はそれを埋めるように答えてください。"
90
+
91
+ # Gradioインターフェースの定義
92
+ iface = gr.Interface(
93
+ fn=lambda image: perform_inference(model, processor, image, prompt),
94
+ inputs=gr.Image(type="pil"),
95
+ outputs="text",
96
+ title="Qwen2-VL-7B-Instruct Bokete Inference",
97
+ description="画像をアップロードすると、シュールで面白いキャプションが生成される…かも?",
98
+ examples=[["./images/0.jpg"]],
99
+ )
100
+
101
+ # Gradioアプリケーションの起動
102
+ iface.launch()
103
+
104
+ if __name__ == "__main__":
105
+ main()
images/0.jpg ADDED
images/1.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ pillow
4
+ gradio
5
+ qwen-vl-utils