helenai committed on
Commit
cc22261
·
1 Parent(s): a1bcd02

Upload appstream.py

Browse files
Files changed (1) hide show
  1. appstream.py +103 -0
appstream.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from threading import Thread
2
+
3
+ import gradio as gr
4
+ import openvino as ov
5
+ from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
6
+ from llava.conversation import conv_templates
7
+ from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
8
+ from llava.model.builder import load_pretrained_model
9
+ from transformers import TextIteratorStreamer
10
+
11
+ css = """
12
+ .text textarea {font-size: 24px !important;}
13
+ .text p {font-size: 24px !important;}
14
+ """
15
+
16
+ model_path = "llava-med-imf16-llmint4"
17
+ # model_path = "llava-med-imint8-llmint4"
18
+ model_name = get_model_name_from_path(model_path)
19
+
20
+ device = "GPU" if "GPU" in ov.Core().available_devices else "CPU"
21
+ image_device = "NPU" if "NPU" in ov.Core().available_devices else device
22
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
23
+ model_path=model_path,
24
+ model_base=None,
25
+ model_name=model_name,
26
+ device=device,
27
+ openvino=True,
28
+ image_device=image_device,
29
+ )
30
+ print("models loaded")
31
+
32
+
33
def reset_inputs():
    """Clear all demo widgets.

    Returns a (image, question, answer) tuple matching the Reset button's
    output components: no image, empty question text, empty answer text.
    """
    cleared_image = None
    cleared_question = ""
    cleared_answer = ""
    return cleared_image, cleared_question, cleared_answer
35
+
36
+
37
def prepare_inputs_image(image, question, conv_mode="vicuna_v1"):
    """Build the token and image inputs for one image/question pair.

    Args:
        image: PIL image to preprocess for the vision tower.
        question: User question text; any image token the user typed is
            stripped so exactly one placeholder is inserted.
        conv_mode: Conversation template name from ``conv_templates``.
            Defaults to ``"vicuna_v1"`` (the original hard-coded value),
            but callers may select another registered template.

    Returns:
        (input_ids, image_tensor): tokenized prompt of shape (1, seq_len)
        with the image placeholder index inserted, and the preprocessed
        image tensor for the first (only) image.
    """
    # Strip any pre-existing image token, then prepend exactly one.
    qs = question.replace(DEFAULT_IMAGE_TOKEN, "").strip()
    qs = DEFAULT_IMAGE_TOKEN + "\n" + qs  # model.config.mm_use_im_start_end is False

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)  # empty assistant slot: the model fills it
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0)

    image_tensor = process_images([image], image_processor, model.config)[0]
    return input_ids, image_tensor
51
+
52
+
53
def run_inference(image, message):
    """Generate and stream the model's answer for (image, message).

    This is a generator: it yields the accumulated response text after each
    new token so Gradio can update the answer box incrementally.
    """
    if not message:
        # BUG FIX: this function is a generator, so a bare `return ""` raises
        # StopIteration without yielding anything and the previous answer
        # stays on screen. Yield the empty string so the answer box is
        # cleared, then stop.
        yield ""
        return

    input_ids, image_tensor = prepare_inputs_image(image, message)

    # skip_prompt avoids echoing the input; timeout guards against a hung
    # generation thread blocking the UI forever.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "streamer": streamer,
        "input_ids": input_ids,
        "images": image_tensor.unsqueeze(0).half(),
        "do_sample": False,  # deterministic greedy decoding
        "max_new_tokens": 512,
        "use_cache": True,
    }
    # Run generation in a background thread so we can consume the streamer here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    try:
        # Stream output: accumulate and re-yield the full text so far.
        response = ""
        for new_text in streamer:
            response += new_text
            yield response
    finally:
        # Don't leak the generation thread if the consumer stops iterating early.
        thread.join()
79
+
80
+
81
with gr.Blocks(css=css) as demo:
    gr.Markdown("# LLaVA-Med 1.5 OpenVINO Demo")

    # Layout: image upload on the left; question box and streamed answer on the right.
    with gr.Row():
        with gr.Column():
            image_widget = gr.Image(type="pil", label="Upload an Image", height=300, width=500)
        with gr.Column():
            question_widget = gr.Textbox(label="Enter a Question", elem_classes="text", interactive=True)
            answer_widget = gr.Textbox(label="Answer", elem_classes="text")

    with gr.Row():
        run_button = gr.Button("Process")
        clear_button = gr.Button("Reset")

    gr.Markdown("NOTE: This OpenVINO model is unvalidated. Results are provisional and may contain errors. Use this demo to explore AI PC and OpenVINO optimizations")
    gr.Markdown("Source model: [microsoft/LLaVA-Med](https://github.com/microsoft/LLaVA-Med). For research purposes only.")

    # Both the Process button and pressing Enter in the question box run inference.
    inference_io = dict(inputs=[image_widget, question_widget], outputs=answer_widget)
    run_button.click(run_inference, **inference_io)
    question_widget.submit(run_inference, **inference_io)
    clear_button.click(reset_inputs, inputs=[], outputs=[image_widget, question_widget, answer_widget])

if __name__ == "__main__":
    # Bind on all interfaces so the demo is reachable from other machines.
    demo.launch(server_port=7788, server_name="0.0.0.0")