DoctorSlimm commited on
Commit
a539d3b
·
verified ·
1 Parent(s): 9af21de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -63
app.py CHANGED
@@ -1,81 +1,137 @@
1
  import os
2
- import spaces
3
  import torch
 
4
  import gradio as gr
 
 
 
 
 
5
 
 
6
 
7
- # cpu
8
 
9
- zero = torch.Tensor([0]).cuda()
10
- print(zero.device) # <-- 'cpu' 🤔
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- # gpu
14
- model = None
15
 
16
  @spaces.GPU
17
- def greet(prompts, separator):
18
- # print(zero.device) # <-- 'cuda:0' 🤗
19
- from vllm import SamplingParams, LLM
20
- from transformers.utils import move_cache
21
- from huggingface_hub import snapshot_download, login
22
-
23
- global model
24
-
25
- if model is None:
26
- LLM_MODEL_ID = "DoctorSlimm/trim-music-31"
27
- # LLM_MODEL_ID = "mistral-community/Mistral-7B-v0.2"
28
- # LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
29
- os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
30
- fp = snapshot_download(LLM_MODEL_ID, token=os.getenv('HF_TOKEN'), revision='main')
31
- move_cache()
32
- model = LLM(fp)
33
 
34
- sampling_params = dict(
35
- temperature = 0.01,
36
- ignore_eos = False,
37
- max_tokens = int(512 * 2)
38
- )
39
- sampling_params = SamplingParams(**sampling_params)
40
-
41
- multi_prompt = False
42
- separator = separator.strip()
43
- if separator in prompts:
44
- multi_prompt = True
45
- prompts = prompts.split(separator)
46
- else:
47
- prompts = [prompts]
48
- for idx, pt in enumerate(prompts):
49
- print()
50
- print(f'[{idx}]:')
51
- print(pt)
52
-
53
- model_outputs = model.generate(prompts, sampling_params)
54
- generations = []
55
- for output in model_outputs:
56
- for outputs in output.outputs:
57
- generations.append(outputs.text)
58
- if multi_prompt:
59
- return separator.join(generations)
60
- return generations[0]
61
 
62
 
63
  ## make predictions via api ##
64
  # https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
65
 
66
  demo = gr.Interface(
67
- fn=greet,
68
- inputs=[
69
- gr.Text(
70
- value='hello sir!<SEP>bonjour madame...',
71
- placeholder='hello sir!<SEP>bonjour madame...',
72
- label='list of prompts separated by separator'
73
- ),
74
- gr.Text(
75
- value='<SEP>',
76
- placeholder='<SEP>',
77
- label='separator for your prompts'
78
- )],
79
- outputs=gr.Text()
 
 
 
 
 
 
 
 
 
 
80
  )
81
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
2
  import torch
3
+ import spaces
4
  import gradio as gr
5
+ from PIL import Image
6
+ from transformers.utils import move_cache
7
+ from huggingface_hub import snapshot_download
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
+
10
 
11
+ # Load the model and processor
12
 
13
+ MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B"
14
 
15
+ os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
16
+ MODEL_PATH = snapshot_download(MODEL_PATH)
17
+ move_cache()
18
 
19
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
20
+ TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained(
23
+ MODEL_PATH,
24
+ trust_remote_code=True
25
+ )
26
+ model = AutoModelForCausalLM.from_pretrained(
27
+ MODEL_PATH,
28
+ torch_dtype=TORCH_TYPE,
29
+ trust_remote_code=True,
30
+ ).to(DEVICE).eval()
31
 
 
 
32
 
33
  @spaces.GPU
34
+ def generate_caption(image, prompt):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ # Process the image and the prompt
37
+ text_only_template = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
38
+ # inputs = processor(texts=[prompt], images=[image], return_tensors="pt").to('cuda') # move inputs to cuda
39
+
40
+
41
+
42
+ return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
  ## make predictions via api ##
46
  # https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
47
 
48
  demo = gr.Interface(
49
+ fn=generate_caption,
50
+ inputs=[gr.Image(type="pil", label="Upload Image"), gr.Textbox(label="Prompt", value="Describe the image in great detail")],
51
+ outputs=gr.Textbox(label="Generated Caption"),
52
+ description=description
53
+ )
54
+
55
+ # Launch the interface
56
+ demo.launch(share=True)
57
+
58
+
59
+
60
+ ####### ML CODE #######
61
+ import torch
62
+ from PIL import Image
63
+ from transformers import AutoModelForCausalLM, AutoTokenizer
64
+
65
+ MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B"
66
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
67
+ TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
68
+
69
+ tokenizer = AutoTokenizer.from_pretrained(
70
+ MODEL_PATH,
71
+ trust_remote_code=True
72
  )
73
+ model = AutoModelForCausalLM.from_pretrained(
74
+ MODEL_PATH,
75
+ torch_dtype=TORCH_TYPE,
76
+ trust_remote_code=True,
77
+ ).to(DEVICE).eval()
78
+
79
+ text_only_template = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
80
+
81
+ while True:
82
+ image_path = input("image path >>>>> ")
83
+ if image_path == '':
84
+ print('You did not enter image path, the following will be a plain text conversation.')
85
+ image = None
86
+ text_only_first_query = True
87
+ else:
88
+ image = Image.open(image_path).convert('RGB')
89
+
90
+ history = []
91
+
92
+ while True:
93
+ query = input("Human:")
94
+ if query == "clear":
95
+ break
96
+
97
+ if image is None:
98
+ if text_only_first_query:
99
+ query = text_only_template.format(query)
100
+ text_only_first_query = False
101
+ else:
102
+ old_prompt = ''
103
+ for _, (old_query, response) in enumerate(history):
104
+ old_prompt += old_query + " " + response + "\n"
105
+ query = old_prompt + "USER: {} ASSISTANT:".format(query)
106
+ if image is None:
107
+ input_by_model = model.build_conversation_input_ids(
108
+ tokenizer,
109
+ query=query,
110
+ history=history,
111
+ template_version='chat'
112
+ )
113
+ else:
114
+ input_by_model = model.build_conversation_input_ids(
115
+ tokenizer,
116
+ query=query,
117
+ history=history,
118
+ images=[image],
119
+ template_version='chat'
120
+ )
121
+ inputs = {
122
+ 'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
123
+ 'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
124
+ 'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
125
+ 'images': [[input_by_model['images'][0].to(DEVICE).to(TORCH_TYPE)]] if image is not None else None,
126
+ }
127
+ gen_kwargs = {
128
+ "max_new_tokens": 2048,
129
+ "pad_token_id": 128002,
130
+ }
131
+ with torch.no_grad():
132
+ outputs = model.generate(**inputs, **gen_kwargs)
133
+ outputs = outputs[:, inputs['input_ids'].shape[1]:]
134
+ response = tokenizer.decode(outputs[0])
135
+ response = response.split("<|end_of_text|>")[0]
136
+ print("\nCogVLM2:", response)
137
+ history.append((query, response))