erndgn committed on
Commit
144aa8f
·
verified ·
1 Parent(s): 41d619b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +76 -84
  2. requirements.txt +6 -4
app.py CHANGED
@@ -1,38 +1,21 @@
1
  import spaces
2
-
3
- import time
4
  from threading import Thread
5
-
6
  import gradio as gr
7
  import torch
8
  from PIL import Image
9
- from transformers import AutoProcessor
10
- from llava.constants import (
11
- IMAGE_TOKEN_INDEX,
12
- DEFAULT_IMAGE_TOKEN,
13
- DEFAULT_IM_START_TOKEN,
14
- DEFAULT_IM_END_TOKEN,
15
- IMAGE_PLACEHOLDER,
16
- )
17
- from llava.model.builder import load_pretrained_model
18
- from llava.utils import disable_torch_init
19
- from llava.mm_utils import (
20
- process_images,
21
- tokenizer_image_token,
22
- get_model_name_from_path,
23
- )
24
  from io import BytesIO
25
  import requests
26
  import os
27
- from conversation import Conversation, SeparatorStyle
28
 
29
  model_id = "ytu-ce-cosmos/Turkish-LLaVA-v0.1"
30
 
31
- disable_torch_init()
32
- model_name = get_model_name_from_path(model_id)
33
- tokenizer, model, image_processor, context_len = load_pretrained_model(
34
- model_id, None, model_name
35
- )
36
 
37
  def load_image(image_file):
38
  if image_file.startswith("http") or image_file.startswith("https"):
@@ -44,63 +27,13 @@ def load_image(image_file):
44
  raise FileNotFoundError(f"Görüntü dosyası {image_file} bulunamadı.")
45
  return image
46
 
47
- def infer_single_image(model_id, image_file, prompt):
48
- image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
49
- if IMAGE_PLACEHOLDER in prompt:
50
- if model.config.mm_use_im_start_end:
51
- prompt = re.sub(IMAGE_PLACEHOLDER, image_token_se, prompt)
52
- else:
53
- prompt = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, prompt)
54
- else:
55
- if model.config.mm_use_im_start_end:
56
- prompt = image_token_se + "\n" + prompt
57
- else:
58
- prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
59
-
60
- conv = Conversation(
61
- system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nSen bir yapay zeka asistanısın. Kullanıcı sana bir görev verecek. Amacın görevi olabildiğince sadık bir şekilde tamamlamak. Görevi yerine getirirken adım adım düşün ve adımlarını gerekçelendir.""",
62
- roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
63
- version="llama3",
64
- messages=[],
65
- offset=0,
66
- sep_style=SeparatorStyle.MPT,
67
- sep="<|eot_id|>",
68
- )
69
- conv.append_message(conv.roles[0], prompt)
70
- conv.append_message(conv.roles[1], None)
71
- full_prompt = conv.get_prompt()
72
-
73
- print("full prompt: ", full_prompt)
74
-
75
- image = load_image(image_file)
76
- image_tensor = process_images(
77
- [image],
78
- image_processor,
79
- model.config
80
- ).to(model.device, dtype=torch.float16)
81
-
82
- input_ids = (
83
- tokenizer_image_token(full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
84
- .unsqueeze(0)
85
- .cuda()
86
- )
87
-
88
- with torch.inference_mode():
89
- output_ids = model.generate(
90
- input_ids,
91
- images=image_tensor,
92
- image_sizes=[image.size],
93
- do_sample=False,
94
- max_new_tokens=512,
95
- use_cache=True,
96
- )
97
-
98
- output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
99
- return output
100
-
101
  @spaces.GPU
 
102
  def bot_streaming(message, history):
103
  print(message)
 
 
 
104
  if message["files"]:
105
  if type(message["files"][-1]) == dict:
106
  image = message["files"][-1]["path"]
@@ -110,19 +43,78 @@ def bot_streaming(message, history):
110
  for hist in history:
111
  if type(hist[0]) == tuple:
112
  image = hist[0][0]
 
113
  try:
114
  if image is None:
115
- gr.Error("LLaVA'nın çalışması için bir resim yüklemeniz gerekir.")
 
116
  except NameError:
117
- gr.Error("LLaVA'nın çalışması için bir resim yüklemeniz gerekir.")
 
118
 
119
- prompt = message['text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- result = infer_single_image(model_id, image, prompt)
 
 
 
 
 
 
122
 
123
- print(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- yield result
 
 
 
 
 
 
126
 
127
  chatbot = gr.Chatbot(scale=1)
128
  chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Mesaj girin veya dosya yükleyin...", show_label=False)
 
1
  import spaces
 
 
2
  from threading import Thread
 
3
  import gradio as gr
4
  import torch
5
  from PIL import Image
6
+ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, LlavaForConditionalGeneration, TextIteratorStreamer
7
+ import torchvision.transforms.functional as TVF
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from io import BytesIO
9
  import requests
10
  import os
 
11
 
12
  model_id = "ytu-ce-cosmos/Turkish-LLaVA-v0.1"
13
 
14
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
15
+ assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Expected PreTrainedTokenizer, got {type(tokenizer)}"
16
+
17
+ model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
18
+ assert isinstance(model, LlavaForConditionalGeneration), f"Expected LlavaForConditionalGeneration, got {type(model)}"
19
 
20
  def load_image(image_file):
21
  if image_file.startswith("http") or image_file.startswith("https"):
 
27
  raise FileNotFoundError(f"Görüntü dosyası {image_file} bulunamadı.")
28
  return image
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  @spaces.GPU
31
+ @torch.no_grad()
32
  def bot_streaming(message, history):
33
  print(message)
34
+ torch.cuda.empty_cache()
35
+
36
+ image = None
37
  if message["files"]:
38
  if type(message["files"][-1]) == dict:
39
  image = message["files"][-1]["path"]
 
43
  for hist in history:
44
  if type(hist[0]) == tuple:
45
  image = hist[0][0]
46
+
47
  try:
48
  if image is None:
49
+ yield "LLaVA'nın çalışması için bir resim yüklemeniz gerekir."
50
+ return
51
  except NameError:
52
+ yield "LLaVA'nın çalışması için bir resim yüklemeniz gerekir."
53
+ return
54
 
55
+ prompt = message['text'].strip()
56
+
57
+ image_pil = load_image(image)
58
+
59
+ if image_pil.size != (336, 336):
60
+ image_pil = image_pil.resize((336, 336), Image.LANCZOS)
61
+ image_pil = image_pil.convert("RGB")
62
+
63
+ pixel_values = TVF.pil_to_tensor(image_pil)
64
+ pixel_values = pixel_values.unsqueeze(0).to("cuda")
65
+ pixel_values = pixel_values / 255.0
66
+ pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
67
+ pixel_values = pixel_values.to(torch.bfloat16)
68
+
69
+ convo = [
70
+ {
71
+ "role": "system",
72
+ "content": "Sen bir yapay zeka asistanısın. Kullanıcı sana bir görev verecek. Amacın görevi olabildiğince sadık bir şekilde tamamlamak. Görevi yerine getirirken adım adım düşün ve adımlarını gerekçelendir."
73
+ },
74
+ {
75
+ "role": "user",
76
+ "content": prompt,
77
+ },
78
+ ]
79
+
80
+ convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
81
+
82
+ convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)
83
 
84
+ input_tokens = []
85
+ for token in convo_tokens:
86
+ if hasattr(model.config, 'image_token_index') and token == model.config.image_token_index:
87
+ seq_length = getattr(model.config, 'image_seq_length', 576)
88
+ input_tokens.extend([model.config.image_token_index] * seq_length)
89
+ else:
90
+ input_tokens.append(token)
91
 
92
+ input_ids = torch.tensor(input_tokens, dtype=torch.long)
93
+ attention_mask = torch.ones_like(input_ids)
94
+
95
+ input_ids = input_ids.unsqueeze(0).to("cuda")
96
+ attention_mask = attention_mask.unsqueeze(0).to("cuda")
97
+
98
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
99
+
100
+ generate_kwargs = dict(
101
+ input_ids=input_ids,
102
+ pixel_values=pixel_values,
103
+ attention_mask=attention_mask,
104
+ max_new_tokens=512,
105
+ do_sample=False,
106
+ suppress_tokens=None,
107
+ use_cache=True,
108
+ streamer=streamer,
109
+ )
110
 
111
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
112
+ t.start()
113
+
114
+ outputs = []
115
+ for text in streamer:
116
+ outputs.append(text)
117
+ yield "".join(outputs)
118
 
119
  chatbot = gr.Chatbot(scale=1)
120
  chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Mesaj girin veya dosya yükleyin...", show_label=False)
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
- llava-torch
2
- spaces
3
- torch
4
- torchvision
 
 
 
1
+ huggingface_hub==0.30.1
2
+ accelerate
3
+ torch
4
+ transformers==4.51.0
5
+ sentencepiece
6
+ torchvision