tricktreat commited on
Commit
5b71c3a
Β·
1 Parent(s): 398dbad
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ logs/
2
+ models
3
+ public/
4
+ *.pyc
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Test
3
  emoji: 😻
4
  colorFrom: gray
5
  colorTo: yellow
 
1
  ---
2
+ title: HuggingGPT
3
  emoji: 😻
4
  colorFrom: gray
5
  colorTo: yellow
app.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import gradio as gr
3
+ import re
4
+ from diffusers.utils import load_image
5
+ import requests
6
+ from awesome_chat import chat_huggingface
7
+ import os
8
+
9
+ all_messages = []
10
+ OPENAI_KEY = ""
11
+
12
+ os.makedirs("public/images", exist_ok=True)
13
+ os.makedirs("public/audios", exist_ok=True)
14
+ os.makedirs("public/videos", exist_ok=True)
15
+
16
+ def add_message(content, role):
17
+ message = {"role":role, "content":content}
18
+ all_messages.append(message)
19
+
20
+ def extract_medias(message):
21
+ image_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(jpg|jpeg|tiff|gif|png)")
22
+ image_urls = []
23
+ for match in image_pattern.finditer(message):
24
+ if match.group(0) not in image_urls:
25
+ image_urls.append(match.group(0))
26
+
27
+ audio_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav)")
28
+ audio_urls = []
29
+ for match in audio_pattern.finditer(message):
30
+ if match.group(0) not in audio_urls:
31
+ audio_urls.append(match.group(0))
32
+
33
+ video_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(mp4)")
34
+ video_urls = []
35
+ for match in video_pattern.finditer(message):
36
+ if match.group(0) not in video_urls:
37
+ video_urls.append(match.group(0))
38
+
39
+ return image_urls, audio_urls, video_urls
40
+
41
+ def set_openai_key(openai_key):
42
+ global OPENAI_KEY
43
+ OPENAI_KEY = openai_key
44
+ return OPENAI_KEY
45
+
46
+ def add_text(messages, message):
47
+ if len(OPENAI_KEY) == 0 or not OPENAI_KEY.startswith("sk-"):
48
+ return messages, "Please set your OpenAI API key first."
49
+ add_message(message, "user")
50
+ messages = messages + [(message, None)]
51
+ image_urls, audio_urls, video_urls = extract_medias(message)
52
+
53
+ for image_url in image_urls:
54
+ if not image_url.startswith("http") and not image_url.startswith("public"):
55
+ image_url = "public/" + image_url
56
+ image = load_image(image_url)
57
+ name = f"public/images/{str(uuid.uuid4())[:4]}.jpg"
58
+ image.save(name)
59
+ messages = messages + [((f"{name}",), None)]
60
+ for audio_url in audio_urls and not audio_url.startswith("public"):
61
+ if not audio_url.startswith("http"):
62
+ audio_url = "public/" + audio_url
63
+ ext = audio_url.split(".")[-1]
64
+ name = f"public/audios/{str(uuid.uuid4()[:4])}.{ext}"
65
+ response = requests.get(audio_url)
66
+ with open(name, "wb") as f:
67
+ f.write(response.content)
68
+ messages = messages + [((f"{name}",), None)]
69
+ for video_url in video_urls and not video_url.startswith("public"):
70
+ if not video_url.startswith("http"):
71
+ video_url = "public/" + video_url
72
+ ext = video_url.split(".")[-1]
73
+ name = f"public/audios/{str(uuid.uuid4()[:4])}.{ext}"
74
+ response = requests.get(video_url)
75
+ with open(name, "wb") as f:
76
+ f.write(response.content)
77
+ messages = messages + [((f"{name}",), None)]
78
+ return messages, ""
79
+
80
+ def bot(messages):
81
+ if len(OPENAI_KEY) == 0 or not OPENAI_KEY.startswith("sk-"):
82
+ return messages
83
+ message = chat_huggingface(all_messages, OPENAI_KEY)["message"]
84
+ image_urls, audio_urls, video_urls = extract_medias(message)
85
+ add_message(message, "assistant")
86
+ messages[-1][1] = message
87
+ for image_url in image_urls:
88
+ image_url = image_url.replace("public/", "")
89
+ messages = messages + [((None, (f"public/{image_url}",)))]
90
+ for audio_url in audio_urls:
91
+ audio_url = audio_url.replace("public/", "")
92
+ messages = messages + [((None, (f"public/{audio_url}",)))]
93
+ for video_url in video_urls:
94
+ video_url = video_url.replace("public/", "")
95
+ messages = messages + [((None, (f"public/{video_url}",)))]
96
+ return messages
97
+
98
+ with gr.Blocks() as demo:
99
+ gr.Markdown("<h1><center>HuggingGPT</center></h1>")
100
+ gr.Markdown("<p align='center'><img src='https://i.ibb.co/qNH3Jym/logo.png' height='25' width='95'></p>")
101
+
102
+ gr.Markdown("<p align='center' style='font-size: 20px;'>A system to connect LLMs with ML community. See our <a href='https://github.com/microsoft/JARVIS'>Project</a> and <a href='http://arxiv.org/abs/2303.17580'>Paper</a>.</p>")
103
+ with gr.Row().style(equal_height=True):
104
+ with gr.Column(scale=0.85):
105
+ openai_api_key = gr.Textbox(
106
+ show_label=False,
107
+ placeholder="Set your OpenAI API key here and press Enter",
108
+ lines=1,
109
+ type="password",
110
+ )
111
+ with gr.Column(scale=0.15, min_width=0):
112
+ btn1 = gr.Button("Submit").style(full_height=True)
113
+
114
+ chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
115
+
116
+ with gr.Row().style(equal_height=True):
117
+ with gr.Column(scale=0.85):
118
+ txt = gr.Textbox(
119
+ show_label=False,
120
+ placeholder="Enter text and press enter. The url of the multimedia resource must contain the extension name.",
121
+ lines=1,
122
+ )
123
+ with gr.Column(scale=0.15, min_width=0):
124
+ btn2 = gr.Button("Send").style(full_height=True)
125
+
126
+ txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
127
+ bot, chatbot, chatbot
128
+ )
129
+ openai_api_key.submit(set_openai_key, [openai_api_key], [openai_api_key])
130
+
131
+ btn1.click(set_openai_key, [openai_api_key], [openai_api_key])
132
+
133
+ btn2.click(add_text, [chatbot, txt], [chatbot, txt]).then(
134
+ bot, chatbot, chatbot
135
+ )
136
+
137
+ gr.Examples(
138
+ examples=["Given a collection of image A: /examples/a.jpg, B: /examples/b.jpg, C: /examples/c.jpg, please tell me how many zebras in these picture?",
139
+ "Please generate a canny image based on /examples/f.jpg",
140
+ "show me a joke and an image of cat",
141
+ "what is in the examples/a.jpg",
142
+ "generate a video and audio about a dog is running on the grass",
143
+ "based on the /examples/a.jpg, please generate a video and audio",
144
+ "based on pose of /examples/d.jpg and content of /examples/e.jpg, please show me a new image",
145
+ ],
146
+ inputs=txt
147
+ )
148
+
149
+ demo.launch()
awesome_chat.py ADDED
@@ -0,0 +1,896 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import copy
3
+ from io import BytesIO
4
+ import io
5
+ import os
6
+ import random
7
+ import time
8
+ import traceback
9
+ import uuid
10
+ import requests
11
+ import re
12
+ import json
13
+ import logging
14
+ import argparse
15
+ import yaml
16
+ from PIL import Image, ImageDraw
17
+ from diffusers.utils import load_image
18
+ from pydub import AudioSegment
19
+ import threading
20
+ from queue import Queue
21
+ import flask
22
+ from flask import request, jsonify
23
+ import waitress
24
+ from flask_cors import CORS
25
+ from get_token_ids import get_token_ids_for_task_parsing, get_token_ids_for_choose_model, count_tokens, get_max_context_length
26
+ from huggingface_hub.inference_api import InferenceApi
27
+ from huggingface_hub.inference_api import ALL_TASKS
28
+ from models_server import models, status
29
+ from functools import partial
30
+
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument("--config", type=str, default="config.yaml.dev")
33
+ parser.add_argument("--mode", type=str, default="cli")
34
+ args = parser.parse_args()
35
+
36
+ if __name__ != "__main__":
37
+ args.config = "config.gradio.yaml"
38
+
39
+ config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)
40
+
41
+ if not os.path.exists("logs"):
42
+ os.mkdir("logs")
43
+
44
+ logger = logging.getLogger(__name__)
45
+ logger.setLevel(logging.DEBUG)
46
+
47
+ handler = logging.StreamHandler()
48
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
49
+ handler.setFormatter(formatter)
50
+ if not config["debug"]:
51
+ handler.setLevel(logging.INFO)
52
+ logger.addHandler(handler)
53
+
54
+ log_file = config["log_file"]
55
+ if log_file:
56
+ filehandler = logging.FileHandler(log_file)
57
+ filehandler.setLevel(logging.DEBUG)
58
+ filehandler.setFormatter(formatter)
59
+ logger.addHandler(filehandler)
60
+
61
+ LLM = config["model"]
62
+ use_completion = config["use_completion"]
63
+
64
+ # consistent: wrong msra model name
65
+ LLM_encoding = LLM
66
+ if LLM == "gpt-3.5-turbo":
67
+ LLM_encoding = "text-davinci-003"
68
+ task_parsing_highlight_ids = get_token_ids_for_task_parsing(LLM_encoding)
69
+ choose_model_highlight_ids = get_token_ids_for_choose_model(LLM_encoding)
70
+
71
+ # ENDPOINT MODEL NAME
72
+ # /v1/chat/completions gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301
73
+ # /v1/completions text-davinci-003, text-davinci-002, text-curie-001, text-babbage-001, text-ada-001, davinci, curie, babbage, ada
74
+
75
+ if use_completion:
76
+ api_name = "completions"
77
+ else:
78
+ api_name = "chat/completions"
79
+
80
+ if not config["dev"]:
81
+ if not config["openai"]["key"].startswith("sk-") and not config["openai"]["key"]=="gradio":
82
+ raise ValueError("Incrorrect OpenAI key. Please check your config.yaml file.")
83
+ OPENAI_KEY = config["openai"]["key"]
84
+ endpoint = f"https://api.openai.com/v1/{api_name}"
85
+ if OPENAI_KEY.startswith("sk-"):
86
+ HEADER = {
87
+ "Authorization": f"Bearer {OPENAI_KEY}"
88
+ }
89
+ else:
90
+ HEADER = None
91
+ else:
92
+ endpoint = f"{config['local']['endpoint']}/v1/{api_name}"
93
+ HEADER = None
94
+
95
+ PROXY = None
96
+ if config["proxy"]:
97
+ PROXY = {
98
+ "https": config["proxy"],
99
+ }
100
+
101
+ inference_mode = config["inference_mode"]
102
+
103
+
104
+ parse_task_demos_or_presteps = open(config["demos_or_presteps"]["parse_task"], "r").read()
105
+ choose_model_demos_or_presteps = open(config["demos_or_presteps"]["choose_model"], "r").read()
106
+ response_results_demos_or_presteps = open(config["demos_or_presteps"]["response_results"], "r").read()
107
+
108
+ parse_task_prompt = config["prompt"]["parse_task"]
109
+ choose_model_prompt = config["prompt"]["choose_model"]
110
+ response_results_prompt = config["prompt"]["response_results"]
111
+
112
+ parse_task_tprompt = config["tprompt"]["parse_task"]
113
+ choose_model_tprompt = config["tprompt"]["choose_model"]
114
+ response_results_tprompt = config["tprompt"]["response_results"]
115
+
116
+ MODELS = [json.loads(line) for line in open("data/p0_models.jsonl", "r").readlines()]
117
+ MODELS_MAP = {}
118
+ for model in MODELS:
119
+ tag = model["task"]
120
+ if tag not in MODELS_MAP:
121
+ MODELS_MAP[tag] = []
122
+ MODELS_MAP[tag].append(model)
123
+ METADATAS = {}
124
+ for model in MODELS:
125
+ METADATAS[model["id"]] = model
126
+
127
+ HUGGINGFACE_HEADERS = {}
128
+ if config["huggingface"]["token"]:
129
+ HUGGINGFACE_HEADERS = {
130
+ "Authorization": f"Bearer {config['huggingface']['token']}",
131
+ }
132
+
133
+ def convert_chat_to_completion(data):
134
+ messages = data.pop('messages', [])
135
+ tprompt = ""
136
+ if messages[0]['role'] == "system":
137
+ tprompt = messages[0]['content']
138
+ messages = messages[1:]
139
+ final_prompt = ""
140
+ for message in messages:
141
+ if message['role'] == "user":
142
+ final_prompt += ("<im_start>"+ "user" + "\n" + message['content'] + "<im_end>\n")
143
+ elif message['role'] == "assistant":
144
+ final_prompt += ("<im_start>"+ "assistant" + "\n" + message['content'] + "<im_end>\n")
145
+ else:
146
+ final_prompt += ("<im_start>"+ "system" + "\n" + message['content'] + "<im_end>\n")
147
+ final_prompt = tprompt + final_prompt
148
+ final_prompt = final_prompt + "<im_start>assistant"
149
+ data["prompt"] = final_prompt
150
+ data['stop'] = data.get('stop', ["<im_end>"])
151
+ data['max_tokens'] = data.get('max_tokens', max(get_max_context_length(LLM) - count_tokens(LLM_encoding, final_prompt), 1))
152
+ return data
153
+
154
+ def send_request(data):
155
+ global HEADER
156
+ openaikey = data.pop("openaikey")
157
+ if use_completion:
158
+ data = convert_chat_to_completion(data)
159
+ if openaikey and openaikey.startswith("sk-"):
160
+ HEADER = {
161
+ "Authorization": f"Bearer {openaikey}"
162
+ }
163
+
164
+ response = requests.post(endpoint, json=data, headers=HEADER, proxies=PROXY)
165
+ logger.debug(response.text.strip())
166
+ if use_completion:
167
+ return response.json()["choices"][0]["text"].strip()
168
+ else:
169
+ return response.json()["choices"][0]["message"]["content"].strip()
170
+
171
+ def replace_slot(text, entries):
172
+ for key, value in entries.items():
173
+ if not isinstance(value, str):
174
+ value = str(value)
175
+ text = text.replace("{{" + key +"}}", value.replace('"', "'").replace('\n', ""))
176
+ return text
177
+
178
+ def find_json(s):
179
+ s = s.replace("\'", "\"")
180
+ start = s.find("{")
181
+ end = s.rfind("}")
182
+ res = s[start:end+1]
183
+ res = res.replace("\n", "")
184
+ return res
185
+
186
+ def field_extract(s, field):
187
+ try:
188
+ field_rep = re.compile(f'{field}.*?:.*?"(.*?)"', re.IGNORECASE)
189
+ extracted = field_rep.search(s).group(1).replace("\"", "\'")
190
+ except:
191
+ field_rep = re.compile(f'{field}:\ *"(.*?)"', re.IGNORECASE)
192
+ extracted = field_rep.search(s).group(1).replace("\"", "\'")
193
+ return extracted
194
+
195
+ def get_id_reason(choose_str):
196
+ reason = field_extract(choose_str, "reason")
197
+ id = field_extract(choose_str, "id")
198
+ choose = {"id": id, "reason": reason}
199
+ return id.strip(), reason.strip(), choose
200
+
201
+ def record_case(success, **args):
202
+ if success:
203
+ f = open("logs/log_success.jsonl", "a")
204
+ else:
205
+ f = open("logs/log_fail.jsonl", "a")
206
+ log = args
207
+ f.write(json.dumps(log) + "\n")
208
+ f.close()
209
+
210
+ def image_to_bytes(img_url):
211
+ img_byte = io.BytesIO()
212
+ type = img_url.split(".")[-1]
213
+ load_image(img_url).save(img_byte, format="png")
214
+ img_data = img_byte.getvalue()
215
+ return img_data
216
+
217
+ def resource_has_dep(command):
218
+ args = command["args"]
219
+ for _, v in args.items():
220
+ if "<GENERATED>" in v:
221
+ return True
222
+ return False
223
+
224
+ def fix_dep(tasks):
225
+ for task in tasks:
226
+ args = task["args"]
227
+ task["dep"] = []
228
+ for k, v in args.items():
229
+ if "<GENERATED>" in v:
230
+ dep_task_id = int(v.split("-")[1])
231
+ if dep_task_id not in task["dep"]:
232
+ task["dep"].append(dep_task_id)
233
+ if len(task["dep"]) == 0:
234
+ task["dep"] = [-1]
235
+ return tasks
236
+
237
+ def unfold(tasks):
238
+ flag_unfold_task = False
239
+ try:
240
+ for task in tasks:
241
+ for key, value in task["args"].items():
242
+ if "<GENERATED>" in value:
243
+ generated_items = value.split(",")
244
+ if len(generated_items) > 1:
245
+ flag_unfold_task = True
246
+ for item in generated_items:
247
+ new_task = copy.deepcopy(task)
248
+ dep_task_id = int(item.split("-")[1])
249
+ new_task["dep"] = [dep_task_id]
250
+ new_task["args"][key] = item
251
+ tasks.append(new_task)
252
+ tasks.remove(task)
253
+ except Exception as e:
254
+ print(e)
255
+ traceback.print_exc()
256
+ logger.debug("unfold task failed.")
257
+
258
+ if flag_unfold_task:
259
+ logger.debug(f"unfold tasks: {tasks}")
260
+
261
+ return tasks
262
+
263
+ def chitchat(messages, openaikey=None):
264
+ data = {
265
+ "model": LLM,
266
+ "messages": messages,
267
+ "openaikey": openaikey
268
+ }
269
+ return send_request(data)
270
+
271
+ def parse_task(context, input, openaikey=None):
272
+ demos_or_presteps = parse_task_demos_or_presteps
273
+ messages = json.loads(demos_or_presteps)
274
+ messages.insert(0, {"role": "system", "content": parse_task_tprompt})
275
+
276
+ # cut chat logs
277
+ start = 0
278
+ while start <= len(context):
279
+ history = context[start:]
280
+ prompt = replace_slot(parse_task_prompt, {
281
+ "input": input,
282
+ "context": history
283
+ })
284
+ messages.append({"role": "user", "content": prompt})
285
+ history_text = "<im_end>\nuser<im_start>".join([m["content"] for m in messages])
286
+ num = count_tokens(LLM_encoding, history_text)
287
+ if get_max_context_length(LLM) - num > 800:
288
+ break
289
+ messages.pop()
290
+ start += 2
291
+
292
+ logger.debug(messages)
293
+ data = {
294
+ "model": LLM,
295
+ "messages": messages,
296
+ "temperature": 0,
297
+ "logit_bias": {item: config["logit_bias"]["parse_task"] for item in task_parsing_highlight_ids},
298
+ "openaikey": openaikey
299
+ }
300
+ return send_request(data)
301
+
302
+ def choose_model(input, task, metas, openaikey = None):
303
+ prompt = replace_slot(choose_model_prompt, {
304
+ "input": input,
305
+ "task": task,
306
+ "metas": metas,
307
+ })
308
+ demos_or_presteps = replace_slot(choose_model_demos_or_presteps, {
309
+ "input": input,
310
+ "task": task,
311
+ "metas": metas
312
+ })
313
+ messages = json.loads(demos_or_presteps)
314
+ messages.insert(0, {"role": "system", "content": choose_model_tprompt})
315
+ messages.append({"role": "user", "content": prompt})
316
+ logger.debug(messages)
317
+ data = {
318
+ "model": LLM,
319
+ "messages": messages,
320
+ "temperature": 0,
321
+ "logit_bias": {item: config["logit_bias"]["choose_model"] for item in choose_model_highlight_ids}, # 5
322
+ "openaikey": openaikey
323
+ }
324
+ return send_request(data)
325
+
326
+
327
+ def response_results(input, results, openaikey=None):
328
+ results = [v for k, v in sorted(results.items(), key=lambda item: item[0])]
329
+ prompt = replace_slot(response_results_prompt, {
330
+ "input": input,
331
+ })
332
+ demos_or_presteps = replace_slot(response_results_demos_or_presteps, {
333
+ "input": input,
334
+ "processes": results
335
+ })
336
+ messages = json.loads(demos_or_presteps)
337
+ messages.insert(0, {"role": "system", "content": response_results_tprompt})
338
+ messages.append({"role": "user", "content": prompt})
339
+ logger.debug(messages)
340
+ data = {
341
+ "model": LLM,
342
+ "messages": messages,
343
+ "temperature": 0,
344
+ "openaikey": openaikey
345
+ }
346
+ return send_request(data)
347
+
348
+ def huggingface_model_inference(model_id, data, task):
349
+ task_url = f"https://api-inference.huggingface.co/models/{model_id}" # InferenceApi does not yet support some tasks
350
+ inference = InferenceApi(repo_id=model_id, token=config["huggingface"]["token"])
351
+
352
+ # NLP tasks
353
+ if task == "question-answering":
354
+ inputs = {"question": data["text"], "context": (data["context"] if "context" in data else "" )}
355
+ result = inference(inputs)
356
+ if task == "sentence-similarity":
357
+ inputs = {"source_sentence": data["text1"], "target_sentence": data["text2"]}
358
+ result = inference(inputs)
359
+ if task in ["text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation"]:
360
+ inputs = data["text"]
361
+ result = inference(inputs)
362
+
363
+ # CV tasks
364
+ if task == "visual-question-answering" or task == "document-question-answering":
365
+ img_url = data["image"]
366
+ text = data["text"]
367
+ img_data = image_to_bytes(img_url)
368
+ img_base64 = base64.b64encode(img_data).decode("utf-8")
369
+ json_data = {}
370
+ json_data["inputs"] = {}
371
+ json_data["inputs"]["question"] = text
372
+ json_data["inputs"]["image"] = img_base64
373
+ result = requests.post(task_url, headers=HUGGINGFACE_HEADERS, json=json_data).json()
374
+ # result = inference(inputs) # not support
375
+
376
+ if task == "image-to-image":
377
+ img_url = data["image"]
378
+ img_data = image_to_bytes(img_url)
379
+ # result = inference(data=img_data) # not support
380
+ HUGGINGFACE_HEADERS["Content-Length"] = str(len(img_data))
381
+ r = requests.post(task_url, headers=HUGGINGFACE_HEADERS, data=img_data)
382
+ result = r.json()
383
+ if "path" in result:
384
+ result["generated image"] = result.pop("path")
385
+
386
+ if task == "text-to-image":
387
+ inputs = data["text"]
388
+ img = inference(inputs)
389
+ name = str(uuid.uuid4())[:4]
390
+ img.save(f"public/images/{name}.png")
391
+ result = {}
392
+ result["generated image"] = f"/images/{name}.png"
393
+
394
+ if task == "image-segmentation":
395
+ img_url = data["image"]
396
+ img_data = image_to_bytes(img_url)
397
+ image = Image.open(BytesIO(img_data))
398
+ predicted = inference(data=img_data)
399
+ colors = []
400
+ for i in range(len(predicted)):
401
+ colors.append((random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 155))
402
+ for i, pred in enumerate(predicted):
403
+ label = pred["label"]
404
+ mask = pred.pop("mask").encode("utf-8")
405
+ mask = base64.b64decode(mask)
406
+ mask = Image.open(BytesIO(mask), mode='r')
407
+ mask = mask.convert('L')
408
+
409
+ layer = Image.new('RGBA', mask.size, colors[i])
410
+ image.paste(layer, (0, 0), mask)
411
+ name = str(uuid.uuid4())[:4]
412
+ image.save(f"public/images/{name}.jpg")
413
+ result = {}
414
+ result["generated image with segmentation mask"] = f"/images/{name}.jpg"
415
+ result["predicted"] = predicted
416
+
417
+ if task == "object-detection":
418
+ img_url = data["image"]
419
+ img_data = image_to_bytes(img_url)
420
+ predicted = inference(data=img_data)
421
+ image = Image.open(BytesIO(img_data))
422
+ draw = ImageDraw.Draw(image)
423
+ labels = list(item['label'] for item in predicted)
424
+ color_map = {}
425
+ for label in labels:
426
+ if label not in color_map:
427
+ color_map[label] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255))
428
+ for label in predicted:
429
+ box = label["box"]
430
+ draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=color_map[label["label"]], width=2)
431
+ draw.text((box["xmin"]+5, box["ymin"]-15), label["label"], fill=color_map[label["label"]])
432
+ name = str(uuid.uuid4())[:4]
433
+ image.save(f"public/images/{name}.jpg")
434
+ result = {}
435
+ result["generated image with predicted box"] = f"/images/{name}.jpg"
436
+ result["predicted"] = predicted
437
+
438
+ if task in ["image-classification"]:
439
+ img_url = data["image"]
440
+ img_data = image_to_bytes(img_url)
441
+ result = inference(data=img_data)
442
+
443
+ if task == "image-to-text":
444
+ img_url = data["image"]
445
+ img_data = image_to_bytes(img_url)
446
+ HUGGINGFACE_HEADERS["Content-Length"] = str(len(img_data))
447
+ r = requests.post(task_url, headers=HUGGINGFACE_HEADERS, data=img_data)
448
+ result = {}
449
+ if "generated_text" in r.json()[0]:
450
+ result["generated text"] = r.json()[0].pop("generated_text")
451
+
452
+ # AUDIO tasks
453
+ if task == "text-to-speech":
454
+ inputs = data["text"]
455
+ response = inference(inputs, raw_response=True)
456
+ # response = requests.post(task_url, headers=HUGGINGFACE_HEADERS, json={"inputs": text})
457
+ name = str(uuid.uuid4())[:4]
458
+ with open(f"public/audios/{name}.flac", "wb") as f:
459
+ f.write(response.content)
460
+ result = {"generated audio": f"/audios/{name}.flac"}
461
+ if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]:
462
+ audio_url = data["audio"]
463
+ audio_data = requests.get(audio_url, timeout=10).content
464
+ response = inference(data=audio_data, raw_response=True)
465
+ result = response.json()
466
+ if task == "audio-to-audio":
467
+ content = None
468
+ type = None
469
+ for k, v in result[0].items():
470
+ if k == "blob":
471
+ content = base64.b64decode(v.encode("utf-8"))
472
+ if k == "content-type":
473
+ type = "audio/flac".split("/")[-1]
474
+ audio = AudioSegment.from_file(BytesIO(content))
475
+ name = str(uuid.uuid4())[:4]
476
+ audio.export(f"public/audios/{name}.{type}", format=type)
477
+ result = {"generated audio": f"/audios/{name}.{type}"}
478
+ return result
479
+
480
+ def local_model_inference(model_id, data, task):
481
+ inference = partial(models, model_id)
482
+ # contronlet
483
+ if model_id.startswith("lllyasviel/sd-controlnet-"):
484
+ img_url = data["image"]
485
+ text = data["text"]
486
+ results = inference({"img_url": img_url, "text": text})
487
+ if "path" in results:
488
+ results["generated image"] = results.pop("path")
489
+ return results
490
+ if model_id.endswith("-control"):
491
+ img_url = data["image"]
492
+ results = inference({"img_url": img_url})
493
+ if "path" in results:
494
+ results["generated image"] = results.pop("path")
495
+ return results
496
+
497
+ if task == "text-to-video":
498
+ results = inference(data)
499
+ if "path" in results:
500
+ results["generated video"] = results.pop("path")
501
+ return results
502
+
503
+ # NLP tasks
504
+ if task == "question-answering" or task == "sentence-similarity":
505
+ results = inference(json=data)
506
+ return results
507
+ if task in ["text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation"]:
508
+ results = inference(json=data)
509
+ return results
510
+
511
+ # CV tasks
512
+ if task == "depth-estimation":
513
+ img_url = data["image"]
514
+ results = inference({"img_url": img_url})
515
+ if "path" in results:
516
+ results["generated depth image"] = results.pop("path")
517
+ return results
518
+ if task == "image-segmentation":
519
+ img_url = data["image"]
520
+ results = inference({"img_url": img_url})
521
+ results["generated image with segmentation mask"] = results.pop("path")
522
+ return results
523
+ if task == "image-to-image":
524
+ img_url = data["image"]
525
+ results = inference({"img_url": img_url})
526
+ if "path" in results:
527
+ results["generated image"] = results.pop("path")
528
+ return results
529
+ if task == "text-to-image":
530
+ results = inference(data)
531
+ if "path" in results:
532
+ results["generated image"] = results.pop("path")
533
+ return results
534
+ if task == "object-detection":
535
+ img_url = data["image"]
536
+ predicted = inference({"img_url": img_url})
537
+ if "error" in predicted:
538
+ return predicted
539
+ image = load_image(img_url)
540
+ draw = ImageDraw.Draw(image)
541
+ labels = list(item['label'] for item in predicted)
542
+ color_map = {}
543
+ for label in labels:
544
+ if label not in color_map:
545
+ color_map[label] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255))
546
+ for label in predicted:
547
+ box = label["box"]
548
+ draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=color_map[label["label"]], width=2)
549
+ draw.text((box["xmin"]+5, box["ymin"]-15), label["label"], fill=color_map[label["label"]])
550
+ name = str(uuid.uuid4())[:4]
551
+ image.save(f"public/images/{name}.jpg")
552
+ results = {}
553
+ results["generated image with predicted box"] = f"/images/{name}.jpg"
554
+ results["predicted"] = predicted
555
+ return results
556
+ if task in ["image-classification", "image-to-text", "document-question-answering", "visual-question-answering"]:
557
+ img_url = data["image"]
558
+ text = None
559
+ if "text" in data:
560
+ text = data["text"]
561
+ results = inference({"img_url": img_url, "text": text})
562
+ return results
563
+ # AUDIO tasks
564
+ if task == "text-to-speech":
565
+ results = inference(data)
566
+ if "path" in results:
567
+ results["generated audio"] = results.pop("path")
568
+ return results
569
+ if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]:
570
+ audio_url = data["audio"]
571
+ results = inference({"audio_url": audio_url})
572
+ return results
573
+
574
+
575
+ def model_inference(model_id, data, hosted_on, task):
576
+ if hosted_on == "unknown":
577
+ r = status(model_id)
578
+ logger.debug("Local Server Status: " + str(r.json()))
579
+ if r.status_code == 200 and "loaded" in r.json() and r.json()["loaded"]:
580
+ hosted_on = "local"
581
+ else:
582
+ huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
583
+ r = requests.get(huggingfaceStatusUrl, headers=HUGGINGFACE_HEADERS, proxies=PROXY)
584
+ logger.debug("Huggingface Status: " + str(r.json()))
585
+ if r.status_code == 200 and "loaded" in r.json() and r.json()["loaded"]:
586
+ hosted_on = "huggingface"
587
+ try:
588
+ if hosted_on == "local":
589
+ inference_result = local_model_inference(model_id, data, task)
590
+ elif hosted_on == "huggingface":
591
+ inference_result = huggingface_model_inference(model_id, data, task)
592
+ except Exception as e:
593
+ print(e)
594
+ traceback.print_exc()
595
+ inference_result = {"error":{"message": str(e)}}
596
+ return inference_result
597
+
598
+
599
+ def get_model_status(model_id, url, headers, queue = None):
600
+ endpoint_type = "huggingface" if "huggingface" in url else "local"
601
+ if "huggingface" in url:
602
+ r = requests.get(url, headers=headers, proxies=PROXY)
603
+ else:
604
+ r = status(model_id)
605
+ if r.status_code == 200 and "loaded" in r.json() and r.json()["loaded"]:
606
+ if queue:
607
+ queue.put((model_id, True, endpoint_type))
608
+ return True
609
+ else:
610
+ if queue:
611
+ queue.put((model_id, False, None))
612
+ return False
613
+
614
+ def get_avaliable_models(candidates, topk=5):
615
+ all_available_models = {"local": [], "huggingface": []}
616
+ threads = []
617
+ result_queue = Queue()
618
+
619
+ for candidate in candidates:
620
+ model_id = candidate["id"]
621
+
622
+ if inference_mode != "local":
623
+ huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
624
+ thread = threading.Thread(target=get_model_status, args=(model_id, huggingfaceStatusUrl, HUGGINGFACE_HEADERS, result_queue))
625
+ threads.append(thread)
626
+ thread.start()
627
+
628
+ if inference_mode != "huggingface" and config["local_deployment"] != "minimal":
629
+ thread = threading.Thread(target=get_model_status, args=(model_id, "", {}, result_queue))
630
+ threads.append(thread)
631
+ thread.start()
632
+
633
+ result_count = len(threads)
634
+ while result_count:
635
+ model_id, status, endpoint_type = result_queue.get()
636
+ if status and model_id not in all_available_models:
637
+ all_available_models[endpoint_type].append(model_id)
638
+ if len(all_available_models["local"] + all_available_models["huggingface"]) >= topk:
639
+ break
640
+ result_count -= 1
641
+
642
+ for thread in threads:
643
+ thread.join()
644
+
645
+ return all_available_models
646
+
647
+ def collect_result(command, choose, inference_result):
648
+ result = {"task": command}
649
+ result["inference result"] = inference_result
650
+ result["choose model result"] = choose
651
+ logger.debug(f"inference result: {inference_result}")
652
+ return result
653
+
654
+
655
+ def run_task(input, command, results, openaikey = None):
656
+ id = command["id"]
657
+ args = command["args"]
658
+ task = command["task"]
659
+ deps = command["dep"]
660
+ if deps[0] != -1:
661
+ dep_tasks = [results[dep] for dep in deps]
662
+ else:
663
+ dep_tasks = []
664
+
665
+ logger.debug(f"Run task: {id} - {task}")
666
+ logger.debug("Deps: " + json.dumps(dep_tasks))
667
+
668
+ if deps[0] != -1:
669
+ if "image" in args and "<GENERATED>-" in args["image"]:
670
+ resource_id = int(args["image"].split("-")[1])
671
+ if "generated image" in results[resource_id]["inference result"]:
672
+ args["image"] = results[resource_id]["inference result"]["generated image"]
673
+ if "audio" in args and "<GENERATED>-" in args["audio"]:
674
+ resource_id = int(args["audio"].split("-")[1])
675
+ if "generated audio" in results[resource_id]["inference result"]:
676
+ args["audio"] = results[resource_id]["inference result"]["generated audio"]
677
+ if "text" in args and "<GENERATED>-" in args["text"]:
678
+ resource_id = int(args["text"].split("-")[1])
679
+ if "generated text" in results[resource_id]["inference result"]:
680
+ args["text"] = results[resource_id]["inference result"]["generated text"]
681
+
682
+ text = image = audio = None
683
+ for dep_task in dep_tasks:
684
+ if "generated text" in dep_task["inference result"]:
685
+ text = dep_task["inference result"]["generated text"]
686
+ logger.debug("Detect the generated text of dependency task (from results):" + text)
687
+ elif "text" in dep_task["task"]["args"]:
688
+ text = dep_task["task"]["args"]["text"]
689
+ logger.debug("Detect the text of dependency task (from args): " + text)
690
+ if "generated image" in dep_task["inference result"]:
691
+ image = dep_task["inference result"]["generated image"]
692
+ logger.debug("Detect the generated image of dependency task (from results): " + image)
693
+ elif "image" in dep_task["task"]["args"]:
694
+ image = dep_task["task"]["args"]["image"]
695
+ logger.debug("Detect the image of dependency task (from args): " + image)
696
+ if "generated audio" in dep_task["inference result"]:
697
+ audio = dep_task["inference result"]["generated audio"]
698
+ logger.debug("Detect the generated audio of dependency task (from results): " + audio)
699
+ elif "audio" in dep_task["task"]["args"]:
700
+ audio = dep_task["task"]["args"]["audio"]
701
+ logger.debug("Detect the audio of dependency task (from args): " + audio)
702
+
703
+ if "image" in args and "<GENERATED>" in args["image"]:
704
+ if image:
705
+ args["image"] = image
706
+ if "audio" in args and "<GENERATED>" in args["audio"]:
707
+ if audio:
708
+ args["audio"] = audio
709
+ if "text" in args and "<GENERATED>" in args["text"]:
710
+ if text:
711
+ args["text"] = text
712
+
713
+ for resource in ["image", "audio"]:
714
+ if resource in args and not args[resource].startswith("public/") and len(args[resource]) > 0 and not args[resource].startswith("http"):
715
+ args[resource] = f"public/{args[resource]}"
716
+
717
+ if "-text-to-image" in command['task'] and "text" not in args:
718
+ logger.debug("control-text-to-image task, but text is empty, so we use control-generation instead.")
719
+ control = task.split("-")[0]
720
+
721
+ if control == "seg":
722
+ task = "image-segmentation"
723
+ command['task'] = task
724
+ elif control == "depth":
725
+ task = "depth-estimation"
726
+ command['task'] = task
727
+ else:
728
+ task = f"{control}-control"
729
+
730
+ command["args"] = args
731
+ logger.debug(f"parsed task: {command}")
732
+
733
+ if task.endswith("-text-to-image") or task.endswith("-control"):
734
+ if inference_mode != "huggingface":
735
+ if task.endswith("-text-to-image"):
736
+ control = task.split("-")[0]
737
+ best_model_id = f"lllyasviel/sd-controlnet-{control}"
738
+ else:
739
+ best_model_id = task
740
+ hosted_on = "local"
741
+ reason = "ControlNet is the best model for this task."
742
+ choose = {"id": best_model_id, "reason": reason}
743
+ logger.debug(f"chosen model: {choose}")
744
+ else:
745
+ logger.warning(f"Task {command['task']} is not available. ControlNet need to be deployed locally.")
746
+ record_case(success=False, **{"input": input, "task": command, "reason": f"Task {command['task']} is not available. ControlNet need to be deployed locally.", "op":"message"})
747
+ inference_result = {"error": f"service related to ControlNet is not available."}
748
+ results[id] = collect_result(command, "", inference_result)
749
+ return False
750
+ elif task in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]: # ChatGPT Can do
751
+ best_model_id = "ChatGPT"
752
+ reason = "ChatGPT is the best model for this task."
753
+ choose = {"id": best_model_id, "reason": reason}
754
+ messages = [{
755
+ "role": "user",
756
+ "content": f"[ {input} ] contains a task in JSON format {command}, 'task' indicates the task type and 'args' indicates the arguments required for the task. Don't explain the task to me, just help me do it and give me the result. The result must be in text form without any urls."
757
+ }]
758
+ response = chitchat(messages, openaikey)
759
+ results[id] = collect_result(command, choose, {"response": response})
760
+ return True
761
+ else:
762
+ if task not in MODELS_MAP:
763
+ logger.warning(f"no available models on {task} task.")
764
+ record_case(success=False, **{"input": input, "task": command, "reason": f"task not support: {command['task']}", "op":"message"})
765
+ inference_result = {"error": f"{command['task']} not found in available tasks."}
766
+ results[id] = collect_result(command, choose, inference_result)
767
+ return False
768
+
769
+ candidates = MODELS_MAP[task][:10]
770
+ all_avaliable_models = get_avaliable_models(candidates, config["num_candidate_models"])
771
+ all_avaliable_model_ids = all_avaliable_models["local"] + all_avaliable_models["huggingface"]
772
+ logger.debug(f"avaliable models on {command['task']}: {all_avaliable_models}")
773
+
774
+ if len(all_avaliable_model_ids) == 0:
775
+ logger.warning(f"no available models on {command['task']}")
776
+ record_case(success=False, **{"input": input, "task": command, "reason": f"no available models: {command['task']}", "op":"message"})
777
+ inference_result = {"error": f"no available models on {command['task']} task."}
778
+ results[id] = collect_result(command, "", inference_result)
779
+ return False
780
+
781
+ if len(all_avaliable_model_ids) == 1:
782
+ best_model_id = all_avaliable_model_ids[0]
783
+ hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
784
+ reason = "Only one model available."
785
+ choose = {"id": best_model_id, "reason": reason}
786
+ logger.debug(f"chosen model: {choose}")
787
+ else:
788
+ cand_models_info = [
789
+ {
790
+ "id": model["id"],
791
+ "inference endpoint": all_avaliable_models.get(
792
+ "local" if model["id"] in all_avaliable_models["local"] else "huggingface"
793
+ ),
794
+ "likes": model.get("likes"),
795
+ "description": model.get("description", "")[:config["max_description_length"]],
796
+ "language": model.get("language"),
797
+ "tags": model.get("tags"),
798
+ }
799
+ for model in candidates
800
+ if model["id"] in all_avaliable_model_ids
801
+ ]
802
+
803
+ choose_str = choose_model(input, command, cand_models_info, openaikey)
804
+ logger.debug(f"chosen model: {choose_str}")
805
+ try:
806
+ choose = json.loads(choose_str)
807
+ reason = choose["reason"]
808
+ best_model_id = choose["id"]
809
+ hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
810
+ except Exception as e:
811
+ logger.warning(f"the response [ {choose_str} ] is not a valid JSON, try to find the model id and reason in the response.")
812
+ choose_str = find_json(choose_str)
813
+ best_model_id, reason, choose = get_id_reason(choose_str)
814
+ hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
815
+ inference_result = model_inference(best_model_id, args, hosted_on, command['task'])
816
+
817
+ if "error" in inference_result:
818
+ logger.warning(f"Inference error: {inference_result['error']}")
819
+ record_case(success=False, **{"input": input, "task": command, "reason": f"inference error: {inference_result['error']}", "op":"message"})
820
+ results[id] = collect_result(command, choose, inference_result)
821
+ return False
822
+
823
+ results[id] = collect_result(command, choose, inference_result)
824
+ return True
825
+
826
+ def chat_huggingface(messages, openaikey = None, return_planning = False, return_results = False):
827
+ start = time.time()
828
+ context = messages[:-1]
829
+ input = messages[-1]["content"]
830
+ logger.info("*"*80)
831
+ logger.info(f"input: {input}")
832
+
833
+ task_str = parse_task(context, input, openaikey).strip()
834
+ logger.info(task_str)
835
+
836
+ if task_str == "[]": # using LLM response for empty task
837
+ record_case(success=False, **{"input": input, "task": [], "reason": "task parsing fail: empty", "op": "chitchat"})
838
+ response = chitchat(messages, openaikey)
839
+ return {"message": response}
840
+ try:
841
+ tasks = json.loads(task_str)
842
+ except Exception as e:
843
+ logger.debug(e)
844
+ response = chitchat(messages, openaikey)
845
+ record_case(success=False, **{"input": input, "task": task_str, "reason": "task parsing fail", "op":"chitchat"})
846
+ return {"message": response}
847
+
848
+
849
+ tasks = unfold(tasks)
850
+ tasks = fix_dep(tasks)
851
+ logger.debug(tasks)
852
+
853
+ if return_planning:
854
+ return tasks
855
+
856
+ results = {}
857
+ threads = []
858
+ tasks = tasks[:]
859
+ d = dict()
860
+ retry = 0
861
+ while True:
862
+ num_threads = len(threads)
863
+ for task in tasks:
864
+ dep = task["dep"]
865
+ # logger.debug(f"d.keys(): {d.keys()}, dep: {dep}")
866
+ if len(list(set(dep).intersection(d.keys()))) == len(dep) or dep[0] == -1:
867
+ tasks.remove(task)
868
+ thread = threading.Thread(target=run_task, args=(input, task, d, openaikey))
869
+ thread.start()
870
+ threads.append(thread)
871
+ if num_threads == len(threads):
872
+ time.sleep(0.5)
873
+ retry += 1
874
+ if retry > 160:
875
+ logger.debug("User has waited too long, Loop break.")
876
+ break
877
+ if len(tasks) == 0:
878
+ break
879
+ for thread in threads:
880
+ thread.join()
881
+
882
+ results = d.copy()
883
+
884
+ logger.debug(results)
885
+ if return_results:
886
+ return results
887
+
888
+ response = response_results(input, results, openaikey).strip()
889
+
890
+ end = time.time()
891
+ during = end - start
892
+
893
+ answer = {"message": response}
894
+ record_case(success=True, **{"input": input, "task": task_str, "results": results, "response": response, "during": during, "op":"response"})
895
+ logger.info(f"response: {response}")
896
+ return answer
config.gradio.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openai:
2
+ key: gradio # "gradio" (set when request) or your_personal_key
3
+ huggingface:
4
+ token: # required: huggingface token @ https://huggingface.co/settings/tokens
5
+ local: # ignore: just for development
6
+ endpoint: http://localhost:8003
7
+ dev: false
8
+ debug: false
9
+ log_file: logs/debug.log
10
+ model: text-davinci-003 # text-davinci-003
11
+ use_completion: true
12
+ inference_mode: hybrid # local, huggingface or hybrid
13
+ local_deployment: minimal # minimal, standard or full
14
+ num_candidate_models: 5
15
+ max_description_length: 100
16
+ proxy:
17
+ httpserver:
18
+ host: localhost
19
+ port: 8004
20
+ modelserver:
21
+ host: localhost
22
+ port: 8005
23
+ logit_bias:
24
+ parse_task: 0.1
25
+ choose_model: 5
26
+ tprompt:
27
+ parse_task: >-
28
+ #1 Task Planning Stage: The AI assistant can parse user input to several tasks: [{"task": task, "id": task_id, "dep": dependency_task_id, "args": {"text": text or <GENERATED>-dep_id, "image": image_url or <GENERATED>-dep_id, "audio": audio_url or <GENERATED>-dep_id}}]. The special tag "<GENERATED>-dep_id" refer to the one genereted text/image/audio in the dependency task (Please consider whether the dependency task generates resources of this type.) and "dep_id" must be in "dep" list. The "dep" field denotes the ids of the previous prerequisite tasks which generate a new resource that the current task relies on. The "args" field must in ["text", "image", "audio"], nothing else. The task MUST be selected from the following options: "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "text-to-video", "visual-question-answering", "document-question-answering", "image-segmentation", "depth-estimation", "text-to-speech", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image". There may be multiple tasks of the same type. Think step by step about all the tasks needed to resolve the user's request. Parse out as few tasks as possible while ensuring that the user request can be resolved. Pay attention to the dependencies and order among tasks. If the user input can't be parsed, you need to reply empty JSON [].
29
+ choose_model: >-
30
+ #2 Model Selection Stage: Given the user request and the parsed tasks, the AI assistant helps the user to select a suitable model from a list of models to process the user request. The assistant should focus more on the description of the model and find the model that has the most potential to solve requests and tasks. Also, prefer models with local inference endpoints for speed and stability.
31
+ response_results: >-
32
+ #4 Response Generation Stage: With the task execution logs, the AI assistant needs to describe the process and inference results.
33
+ demos_or_presteps:
34
+ parse_task: demos/demo_parse_task.json
35
+ choose_model: demos/demo_choose_model.json
36
+ response_results: demos/demo_response_results.json
37
+ prompt:
38
+ parse_task: The chat log [ {{context}} ] may contain the resources I mentioned. Now I input { {{input}} }. Pay attention to the input and output types of tasks and the dependencies between tasks.
39
+ choose_model: >-
40
+ Please choose the most suitable model from {{metas}} for the task {{task}}. The output must be in a strict JSON format: {"id": "id", "reason": "your detail reasons for the choice"}.
41
+ response_results: >-
42
+ Yes. Please first think carefully and directly answer my request based on the inference results. Then please detail your workflow step by step including the used models and inference results for my request in your friendly tone. Please filter out information that is not relevant to my request. If any generated files of images, audios or videos in the inference results, must tell me the complete path. If there is nothing in the results, please tell me you can't make it. }
data/p0_models.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
demos/demo_choose_model.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "role": "user",
4
+ "content": "{{input}}"
5
+ },
6
+ {
7
+ "role": "assistant",
8
+ "content": "{{task}}"
9
+ }
10
+ ]
demos/demo_parse_task.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "role": "user",
4
+ "content": "Give you some pictures e1.jpg, e2.png, e3.jpg, help me count the number of sheep?"
5
+ },
6
+ {
7
+ "role": "assistant",
8
+ "content": "[{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"e1.jpg\" }}, {\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"e1.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 2, \"dep\": [1], \"args\": {\"image\": \"<GENERATED>-1\", \"text\": \"How many sheep in the picture\"}} }}, {\"task\": \"image-to-text\", \"id\": 3, \"dep\": [-1], \"args\": {\"image\": \"e2.png\" }}, {\"task\": \"object-detection\", \"id\": 4, \"dep\": [-1], \"args\": {\"image\": \"e2.png\" }}, {\"task\": \"visual-question-answering\", \"id\": 5, \"dep\": [4], \"args\": {\"image\": \"<GENERATED>-4\", \"text\": \"How many sheep in the picture\"}} }}, {\"task\": \"image-to-text\", \"id\": 6, \"dep\": [-1], \"args\": {\"image\": \"e3.jpg\" }}, {\"task\": \"object-detection\", \"id\": 7, \"dep\": [-1], \"args\": {\"image\": \"e3.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 8, \"dep\": [7], \"args\": {\"image\": \"<GENERATED>-7\", \"text\": \"How many sheep in the picture\"}}]"
9
+ },
10
+
11
+ {
12
+ "role":"user",
13
+ "content":"Look at /e.jpg, can you tell me how many objects in the picture? Give me a picture and video similar to this one."
14
+ },
15
+ {
16
+ "role":"assistant",
17
+ "content":"[{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 2, \"dep\": [1], \"args\": {\"image\": \"<GENERATED>-1\", \"text\": \"how many objects in the picture?\" }}, {\"task\": \"text-to-image\", \"id\": 3, \"dep\": [0], \"args\": {\"text\": \"<GENERATED-0>\" }}, {\"task\": \"image-to-image\", \"id\": 4, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"text-to-video\", \"id\": 5, \"dep\": [0], \"args\": {\"text\": \"<GENERATED-0>\" }}]"
18
+ },
19
+
20
+ {
21
+ "role":"user",
22
+ "content":"given a document /images/e.jpeg, answer me what is the student amount? And describe the image with your voice"
23
+ },
24
+ {
25
+ "role":"assistant",
26
+ "content":"{\"task\": \"document-question-answering\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\", \"text\": \"what is the student amount?\" }}, {\"task\": \"visual-question-answering\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\", \"text\": \"what is the student amount?\" }}, {\"task\": \"image-to-text\", \"id\": 2, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpg\" }}, {\"task\": \"text-to-speech\", \"id\": 3, \"dep\": [2], \"args\": {\"text\": \"<GENERATED>-2\" }}]"
27
+ },
28
+
29
+ {
30
+ "role": "user",
31
+ "content": "Given an image /example.jpg, first generate a hed image, then based on the hed image generate a new image where a girl is reading a book"
32
+ },
33
+ {
34
+ "role": "assistant",
35
+ "content": "[{\"task\": \"openpose-control\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/example.jpg\" }}, {\"task\": \"openpose-text-to-image\", \"id\": 1, \"dep\": [0], \"args\": {\"text\": \"a girl is reading a book\", \"image\": \"<GENERATED>-0\" }}]"
36
+ },
37
+
38
+ {
39
+ "role": "user",
40
+ "content": "please show me a video and an image of (based on the text) 'a boy is running' and dub it"
41
+ },
42
+ {
43
+ "role": "assistant",
44
+ "content": "[{\"task\": \"text-to-video\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}, {\"task\": \"text-to-speech\", \"id\": 1, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}, {\"task\": \"text-to-image\", \"id\": 2, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}]"
45
+ },
46
+
47
+
48
+ {
49
+ "role": "user",
50
+ "content": "please show me a joke and an image of cat"
51
+ },
52
+ {
53
+ "role": "assistant",
54
+ "content": "[{\"task\": \"conversational\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"please show me a joke of cat\" }}, {\"task\": \"text-to-image\", \"id\": 1, \"dep\": [-1], \"args\": {\"text\": \"a photo of cat\" }}]"
55
+ }
56
+ ]
demos/demo_response_results.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "role": "user",
4
+ "content": "{{input}}"
5
+ },
6
+ {
7
+ "role": "assistant",
8
+ "content": "Before give you a response, I want to introduce my workflow for your request, which is shown in the following JSON data: {{processes}}. Do you have any demands regarding my response?"
9
+ }
10
+ ]
get_token_ids.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tiktoken
2
+
3
+ encodings = {
4
+ "gpt-3.5-turbo": tiktoken.get_encoding("cl100k_base"),
5
+ "gpt-3.5-turbo-0301": tiktoken.get_encoding("cl100k_base"),
6
+ "text-davinci-003": tiktoken.get_encoding("p50k_base"),
7
+ "text-davinci-002": tiktoken.get_encoding("p50k_base"),
8
+ "text-davinci-001": tiktoken.get_encoding("r50k_base"),
9
+ "text-curie-001": tiktoken.get_encoding("r50k_base"),
10
+ "text-babbage-001": tiktoken.get_encoding("r50k_base"),
11
+ "text-ada-001": tiktoken.get_encoding("r50k_base"),
12
+ "davinci": tiktoken.get_encoding("r50k_base"),
13
+ "curie": tiktoken.get_encoding("r50k_base"),
14
+ "babbage": tiktoken.get_encoding("r50k_base"),
15
+ "ada": tiktoken.get_encoding("r50k_base"),
16
+ }
17
+
18
+ max_length = {
19
+ "gpt-3.5-turbo": 4096,
20
+ "gpt-3.5-turbo-0301": 4096,
21
+ "text-davinci-003": 4096,
22
+ "text-davinci-002": 4096,
23
+ "text-davinci-001": 2049,
24
+ "text-curie-001": 2049,
25
+ "text-babbage-001": 2049,
26
+ "text-ada-001": 2049,
27
+ "davinci": 2049,
28
+ "curie": 2049,
29
+ "babbage": 2049,
30
+ "ada": 2049
31
+ }
32
+
33
+ def count_tokens(model_name, text):
34
+ return len(encodings[model_name].encode(text))
35
+
36
+ def get_max_context_length(model_name):
37
+ return max_length[model_name]
38
+
39
+ def get_token_ids_for_task_parsing(model_name):
40
+ text = '''{"task": "text-classification", "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "visual-question-answering", "document-question-answering", "image-segmentation", "text-to-speech", "text-to-video", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image", "args", "text", "path", "dep", "id", "<GENERATED>-"}'''
41
+ res = encodings[model_name].encode(text)
42
+ res = list(set(res))
43
+ return res
44
+
45
+ def get_token_ids_for_choose_model(model_name):
46
+ text = '''{"id": "reason"}'''
47
+ res = encodings[model_name].encode(text)
48
+ res = list(set(res))
49
+ return res
models_server.py ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import random
4
+ import uuid
5
+ import numpy as np
6
+ from transformers import pipeline
7
+ from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
8
+ from diffusers.utils import load_image
9
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
10
+ from diffusers.utils import export_to_video
11
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5ForSpeechToSpeech
12
+ from transformers import BlipProcessor, BlipForConditionalGeneration
13
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
14
+ from datasets import load_dataset
15
+ from PIL import Image
16
+ import io
17
+ from torchvision import transforms
18
+ import torch
19
+ import torchaudio
20
+ from speechbrain.pretrained import WaveformEnhancement
21
+ import joblib
22
+ from huggingface_hub import hf_hub_url, cached_download
23
+ from transformers import AutoImageProcessor, TimesformerForVideoClassification
24
+ from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, AutoFeatureExtractor
25
+ from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector, CannyDetector, MidasDetector
26
+ from controlnet_aux.open_pose.body import Body
27
+ from controlnet_aux.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large
28
+ from controlnet_aux.hed import Network
29
+ from transformers import DPTForDepthEstimation, DPTFeatureExtractor
30
+ import warnings
31
+ import time
32
+ from espnet2.bin.tts_inference import Text2Speech
33
+ import soundfile as sf
34
+ from asteroid.models import BaseModel
35
+ import traceback
36
+ import os
37
+ import yaml
38
+
39
+ warnings.filterwarnings("ignore")
40
+
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--config", type=str, default="config.yaml")
43
+ args = parser.parse_args()
44
+
45
+ if __name__ != "__main__":
46
+ args.config = "config.gradio.yaml"
47
+
48
+ logger = logging.getLogger(__name__)
49
+ logger.setLevel(logging.INFO)
50
+ handler = logging.StreamHandler()
51
+ handler.setLevel(logging.INFO)
52
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
53
+ handler.setFormatter(formatter)
54
+ logger.addHandler(handler)
55
+
56
+ config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)
57
+
58
+ local_deployment = config["local_deployment"]
59
+ if config["inference_mode"] == "huggingface":
60
+ local_deployment = "none"
61
+
62
+ PROXY = None
63
+ if config["proxy"]:
64
+ PROXY = {
65
+ "https": config["proxy"],
66
+ }
67
+
68
+ start = time.time()
69
+
70
+ local_models = ""
71
+
72
+ def load_pipes(local_deployment):
73
+ other_pipes = {}
74
+ standard_pipes = {}
75
+ controlnet_sd_pipes = {}
76
+ if local_deployment in ["full"]:
77
+ other_pipes = {
78
+ "nlpconnect/vit-gpt2-image-captioning":{
79
+ "model": VisionEncoderDecoderModel.from_pretrained(f"nlpconnect/vit-gpt2-image-captioning"),
80
+ "feature_extractor": ViTImageProcessor.from_pretrained(f"nlpconnect/vit-gpt2-image-captioning"),
81
+ "tokenizer": AutoTokenizer.from_pretrained(f"nlpconnect/vit-gpt2-image-captioning"),
82
+ "device": "cuda:0"
83
+ },
84
+ # "Salesforce/blip-image-captioning-large": {
85
+ # "model": BlipForConditionalGeneration.from_pretrained(f"Salesforce/blip-image-captioning-large"),
86
+ # "processor": BlipProcessor.from_pretrained(f"Salesforce/blip-image-captioning-large"),
87
+ # "device": "cuda:0"
88
+ # },
89
+ "damo-vilab/text-to-video-ms-1.7b": {
90
+ "model": DiffusionPipeline.from_pretrained(f"damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"),
91
+ "device": "cuda:0"
92
+ },
93
+ # "facebook/maskformer-swin-large-ade": {
94
+ # "model": MaskFormerForInstanceSegmentation.from_pretrained(f"facebook/maskformer-swin-large-ade"),
95
+ # "feature_extractor" : AutoFeatureExtractor.from_pretrained("facebook/maskformer-swin-large-ade"),
96
+ # "device": "cuda:0"
97
+ # },
98
+ # "microsoft/trocr-base-printed": {
99
+ # "processor": TrOCRProcessor.from_pretrained(f"microsoft/trocr-base-printed"),
100
+ # "model": VisionEncoderDecoderModel.from_pretrained(f"microsoft/trocr-base-printed"),
101
+ # "device": "cuda:0"
102
+ # },
103
+ # "microsoft/trocr-base-handwritten": {
104
+ # "processor": TrOCRProcessor.from_pretrained(f"microsoft/trocr-base-handwritten"),
105
+ # "model": VisionEncoderDecoderModel.from_pretrained(f"microsoft/trocr-base-handwritten"),
106
+ # "device": "cuda:0"
107
+ # },
108
+ "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k": {
109
+ "model": BaseModel.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k"),
110
+ "device": "cuda:0"
111
+ },
112
+ "espnet/kan-bayashi_ljspeech_vits": {
113
+ "model": Text2Speech.from_pretrained(f"espnet/kan-bayashi_ljspeech_vits"),
114
+ "device": "cuda:0"
115
+ },
116
+ "lambdalabs/sd-image-variations-diffusers": {
117
+ "model": DiffusionPipeline.from_pretrained(f"lambdalabs/sd-image-variations-diffusers"), #torch_dtype=torch.float16
118
+ "device": "cuda:0"
119
+ },
120
+ # "CompVis/stable-diffusion-v1-4": {
121
+ # "model": DiffusionPipeline.from_pretrained(f"CompVis/stable-diffusion-v1-4"),
122
+ # "device": "cuda:0"
123
+ # },
124
+ # "stabilityai/stable-diffusion-2-1": {
125
+ # "model": DiffusionPipeline.from_pretrained(f"stabilityai/stable-diffusion-2-1"),
126
+ # "device": "cuda:0"
127
+ # },
128
+ "runwayml/stable-diffusion-v1-5": {
129
+ "model": DiffusionPipeline.from_pretrained(f"runwayml/stable-diffusion-v1-5"),
130
+ "device": "cuda:0"
131
+ },
132
+ # "microsoft/speecht5_tts":{
133
+ # "processor": SpeechT5Processor.from_pretrained(f"microsoft/speecht5_tts"),
134
+ # "model": SpeechT5ForTextToSpeech.from_pretrained(f"microsoft/speecht5_tts"),
135
+ # "vocoder": SpeechT5HifiGan.from_pretrained(f"microsoft/speecht5_hifigan"),
136
+ # "embeddings_dataset": load_dataset(f"Matthijs/cmu-arctic-xvectors", split="validation"),
137
+ # "device": "cuda:0"
138
+ # },
139
+ # "speechbrain/mtl-mimic-voicebank": {
140
+ # "model": WaveformEnhancement.from_hparams(source="speechbrain/mtl-mimic-voicebank", savedir="models/mtl-mimic-voicebank"),
141
+ # "device": "cuda:0"
142
+ # },
143
+ "microsoft/speecht5_vc":{
144
+ "processor": SpeechT5Processor.from_pretrained(f"microsoft/speecht5_vc"),
145
+ "model": SpeechT5ForSpeechToSpeech.from_pretrained(f"microsoft/speecht5_vc"),
146
+ "vocoder": SpeechT5HifiGan.from_pretrained(f"microsoft/speecht5_hifigan"),
147
+ "embeddings_dataset": load_dataset(f"Matthijs/cmu-arctic-xvectors", split="validation"),
148
+ "device": "cuda:0"
149
+ },
150
+ # "julien-c/wine-quality": {
151
+ # "model": joblib.load(cached_download(hf_hub_url("julien-c/wine-quality", "sklearn_model.joblib")))
152
+ # },
153
+ # "facebook/timesformer-base-finetuned-k400": {
154
+ # "processor": AutoImageProcessor.from_pretrained(f"facebook/timesformer-base-finetuned-k400"),
155
+ # "model": TimesformerForVideoClassification.from_pretrained(f"facebook/timesformer-base-finetuned-k400"),
156
+ # "device": "cuda:0"
157
+ # },
158
+ "facebook/maskformer-swin-base-coco": {
159
+ "feature_extractor": MaskFormerFeatureExtractor.from_pretrained(f"facebook/maskformer-swin-base-coco"),
160
+ "model": MaskFormerForInstanceSegmentation.from_pretrained(f"facebook/maskformer-swin-base-coco"),
161
+ "device": "cuda:0"
162
+ },
163
+ "Intel/dpt-hybrid-midas": {
164
+ "model": DPTForDepthEstimation.from_pretrained(f"Intel/dpt-hybrid-midas", low_cpu_mem_usage=True),
165
+ "feature_extractor": DPTFeatureExtractor.from_pretrained(f"Intel/dpt-hybrid-midas"),
166
+ "device": "cuda:0"
167
+ }
168
+ }
169
+
170
+ if local_deployment in ["full", "standard"]:
171
+ standard_pipes = {
172
+ # "superb/wav2vec2-base-superb-ks": {
173
+ # "model": pipeline(task="audio-classification", model=f"superb/wav2vec2-base-superb-ks"),
174
+ # "device": "cuda:0"
175
+ # },
176
+ "openai/whisper-base": {
177
+ "model": pipeline(task="automatic-speech-recognition", model=f"openai/whisper-base"),
178
+ "device": "cuda:0"
179
+ },
180
+ "microsoft/speecht5_asr": {
181
+ "model": pipeline(task="automatic-speech-recognition", model=f"microsoft/speecht5_asr"),
182
+ "device": "cuda:0"
183
+ },
184
+ "Intel/dpt-large": {
185
+ "model": pipeline(task="depth-estimation", model=f"Intel/dpt-large"),
186
+ "device": "cuda:0"
187
+ },
188
+ # "microsoft/beit-base-patch16-224-pt22k-ft22k": {
189
+ # "model": pipeline(task="image-classification", model=f"microsoft/beit-base-patch16-224-pt22k-ft22k"),
190
+ # "device": "cuda:0"
191
+ # },
192
+ "facebook/detr-resnet-50-panoptic": {
193
+ "model": pipeline(task="image-segmentation", model=f"facebook/detr-resnet-50-panoptic"),
194
+ "device": "cuda:0"
195
+ },
196
+ "facebook/detr-resnet-101": {
197
+ "model": pipeline(task="object-detection", model=f"facebook/detr-resnet-101"),
198
+ "device": "cuda:0"
199
+ },
200
+ # "openai/clip-vit-large-patch14": {
201
+ # "model": pipeline(task="zero-shot-image-classification", model=f"openai/clip-vit-large-patch14"),
202
+ # "device": "cuda:0"
203
+ # },
204
+ "google/owlvit-base-patch32": {
205
+ "model": pipeline(task="zero-shot-object-detection", model=f"google/owlvit-base-patch32"),
206
+ "device": "cuda:0"
207
+ },
208
+ # "microsoft/DialoGPT-medium": {
209
+ # "model": pipeline(task="conversational", model=f"microsoft/DialoGPT-medium"),
210
+ # "device": "cuda:0"
211
+ # },
212
+ # "bert-base-uncased": {
213
+ # "model": pipeline(task="fill-mask", model=f"bert-base-uncased"),
214
+ # "device": "cuda:0"
215
+ # },
216
+ # "deepset/roberta-base-squad2": {
217
+ # "model": pipeline(task = "question-answering", model=f"deepset/roberta-base-squad2"),
218
+ # "device": "cuda:0"
219
+ # },
220
+ # "facebook/bart-large-cnn": {
221
+ # "model": pipeline(task="summarization", model=f"facebook/bart-large-cnn"),
222
+ # "device": "cuda:0"
223
+ # },
224
+ # "google/tapas-base-finetuned-wtq": {
225
+ # "model": pipeline(task="table-question-answering", model=f"google/tapas-base-finetuned-wtq"),
226
+ # "device": "cuda:0"
227
+ # },
228
+ # "distilbert-base-uncased-finetuned-sst-2-english": {
229
+ # "model": pipeline(task="text-classification", model=f"distilbert-base-uncased-finetuned-sst-2-english"),
230
+ # "device": "cuda:0"
231
+ # },
232
+ # "gpt2": {
233
+ # "model": pipeline(task="text-generation", model="gpt2"),
234
+ # "device": "cuda:0"
235
+ # },
236
+ # "mrm8488/t5-base-finetuned-question-generation-ap": {
237
+ # "model": pipeline(task="text2text-generation", model=f"mrm8488/t5-base-finetuned-question-generation-ap"),
238
+ # "device": "cuda:0"
239
+ # },
240
+ # "Jean-Baptiste/camembert-ner": {
241
+ # "model": pipeline(task="token-classification", model=f"Jean-Baptiste/camembert-ner", aggregation_strategy="simple"),
242
+ # "device": "cuda:0"
243
+ # },
244
+ # "t5-base": {
245
+ # "model": pipeline(task="translation", model=f"t5-base"),
246
+ # "device": "cuda:0"
247
+ # },
248
+ "impira/layoutlm-document-qa": {
249
+ "model": pipeline(task="document-question-answering", model=f"impira/layoutlm-document-qa"),
250
+ "device": "cuda:0"
251
+ },
252
+ "ydshieh/vit-gpt2-coco-en": {
253
+ "model": pipeline(task="image-to-text", model=f"ydshieh/vit-gpt2-coco-en"),
254
+ "device": "cuda:0"
255
+ },
256
+ "dandelin/vilt-b32-finetuned-vqa": {
257
+ "model": pipeline(task="visual-question-answering", model=f"dandelin/vilt-b32-finetuned-vqa"),
258
+ "device": "cuda:0"
259
+ }
260
+ }
261
+
262
+ if local_deployment in ["full", "standard", "minimal"]:
263
+
264
+ controlnet = ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
265
+ controlnetpipe = StableDiffusionControlNetPipeline.from_pretrained(
266
+ f"{local_models}runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
267
+ )
268
+
269
+ def mlsd_control_network():
270
+ model = MobileV2_MLSD_Large()
271
+ model.load_state_dict(torch.load(f"{local_models}lllyasviel/ControlNet/annotator/ckpts/mlsd_large_512_fp32.pth"), strict=True)
272
+ return MLSDdetector(model)
273
+
274
+
275
+ hed_network = Network(f"{local_models}lllyasviel/ControlNet/annotator/ckpts/network-bsds500.pth")
276
+
277
+ controlnet_sd_pipes = {
278
+ "openpose-control": {
279
+ "model": OpenposeDetector(Body(f"{local_models}lllyasviel/ControlNet/annotator/ckpts/body_pose_model.pth"))
280
+ },
281
+ "mlsd-control": {
282
+ "model": mlsd_control_network()
283
+ },
284
+ "hed-control": {
285
+ "model": HEDdetector(hed_network)
286
+ },
287
+ "scribble-control": {
288
+ "model": HEDdetector(hed_network)
289
+ },
290
+ "midas-control": {
291
+ "model": MidasDetector(model_path=f"{local_models}lllyasviel/ControlNet/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt")
292
+ },
293
+ "canny-control": {
294
+ "model": CannyDetector()
295
+ },
296
+ "lllyasviel/sd-controlnet-canny":{
297
+ "control": controlnet,
298
+ "model": controlnetpipe,
299
+ "device": "cuda:0"
300
+ },
301
+ "lllyasviel/sd-controlnet-depth":{
302
+ "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16),
303
+ "model": controlnetpipe,
304
+ "device": "cuda:0"
305
+ },
306
+ "lllyasviel/sd-controlnet-hed":{
307
+ "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-hed", torch_dtype=torch.float16),
308
+ "model": controlnetpipe,
309
+ "device": "cuda:0"
310
+ },
311
+ "lllyasviel/sd-controlnet-mlsd":{
312
+ "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-mlsd", torch_dtype=torch.float16),
313
+ "model": controlnetpipe,
314
+ "device": "cuda:0"
315
+ },
316
+ "lllyasviel/sd-controlnet-openpose":{
317
+ "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16),
318
+ "model": controlnetpipe,
319
+ "device": "cuda:0"
320
+ },
321
+ "lllyasviel/sd-controlnet-scribble":{
322
+ "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-scribble", torch_dtype=torch.float16),
323
+ "model": controlnetpipe,
324
+ "device": "cuda:0"
325
+ },
326
+ "lllyasviel/sd-controlnet-seg":{
327
+ "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16),
328
+ "model": controlnetpipe,
329
+ "device": "cuda:0"
330
+ }
331
+ }
332
+ pipes = {**standard_pipes, **other_pipes, **controlnet_sd_pipes}
333
+ return pipes
334
+
335
+ pipes = load_pipes(local_deployment)
336
+
337
+ end = time.time()
338
+ during = end - start
339
+
340
+ print(f"[ ready ] {during}s")
341
+
342
+ def running():
343
+ return {"running": True}
344
+
345
+ def status(model_id):
346
+ disabled_models = ["microsoft/trocr-base-printed", "microsoft/trocr-base-handwritten"]
347
+ if model_id in pipes.keys() and model_id not in disabled_models:
348
+ print(f"[ check {model_id} ] success")
349
+ return {"loaded": True}
350
+ else:
351
+ print(f"[ check {model_id} ] failed")
352
+ return {"loaded": False}
353
+
354
+ def models(model_id, data):
355
+ while "using" in pipes[model_id] and pipes[model_id]["using"]:
356
+ print(f"[ inference {model_id} ] waiting")
357
+ time.sleep(0.1)
358
+ pipes[model_id]["using"] = True
359
+ print(f"[ inference {model_id} ] start")
360
+
361
+ start = time.time()
362
+
363
+ pipe = pipes[model_id]["model"]
364
+
365
+ if "device" in pipes[model_id]:
366
+ try:
367
+ pipe.to(pipes[model_id]["device"])
368
+ except:
369
+ pipe.device = torch.device(pipes[model_id]["device"])
370
+ pipe.model.to(pipes[model_id]["device"])
371
+
372
+ result = None
373
+ try:
374
+ # text to video
375
+ if model_id == "damo-vilab/text-to-video-ms-1.7b":
376
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
377
+ # pipe.enable_model_cpu_offload()
378
+ prompt = data["text"]
379
+ video_frames = pipe(prompt, num_inference_steps=50, num_frames=40).frames
380
+ video_path = export_to_video(video_frames)
381
+ file_name = str(uuid.uuid4())[:4]
382
+ os.system(f"LD_LIBRARY_PATH=/usr/local/lib /usr/local/bin/ffmpeg -i {video_path} -vcodec libx264 public/videos/{file_name}.mp4")
383
+ result = {"path": f"/videos/{file_name}.mp4"}
384
+
385
+ # controlnet
386
+ if model_id.startswith("lllyasviel/sd-controlnet-"):
387
+ pipe.controlnet.to('cpu')
388
+ pipe.controlnet = pipes[model_id]["control"].to(pipes[model_id]["device"])
389
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
390
+ control_image = load_image(data["img_url"])
391
+ # generator = torch.manual_seed(66)
392
+ out_image: Image = pipe(data["text"], num_inference_steps=20, image=control_image).images[0]
393
+ file_name = str(uuid.uuid4())[:4]
394
+ out_image.save(f"public/images/{file_name}.png")
395
+ result = {"path": f"/images/{file_name}.png"}
396
+
397
+ if model_id.endswith("-control"):
398
+ image = load_image(data["img_url"])
399
+ if "scribble" in model_id:
400
+ control = pipe(image, scribble = True)
401
+ elif "canny" in model_id:
402
+ control = pipe(image, low_threshold=100, high_threshold=200)
403
+ else:
404
+ control = pipe(image)
405
+ file_name = str(uuid.uuid4())[:4]
406
+ control.save(f"public/images/{file_name}.png")
407
+ result = {"path": f"/images/{file_name}.png"}
408
+
409
+ # image to image
410
+ if model_id == "lambdalabs/sd-image-variations-diffusers":
411
+ im = load_image(data["img_url"])
412
+ file_name = str(uuid.uuid4())[:4]
413
+ with open(f"public/images/{file_name}.png", "wb") as f:
414
+ f.write(data)
415
+ tform = transforms.Compose([
416
+ transforms.ToTensor(),
417
+ transforms.Resize(
418
+ (224, 224),
419
+ interpolation=transforms.InterpolationMode.BICUBIC,
420
+ antialias=False,
421
+ ),
422
+ transforms.Normalize(
423
+ [0.48145466, 0.4578275, 0.40821073],
424
+ [0.26862954, 0.26130258, 0.27577711]),
425
+ ])
426
+ inp = tform(im).to(pipes[model_id]["device"]).unsqueeze(0)
427
+ out = pipe(inp, guidance_scale=3)
428
+ out["images"][0].save(f"public/images/{file_name}.jpg")
429
+ result = {"path": f"/images/{file_name}.jpg"}
430
+
431
+ # image to text
432
+ if model_id == "Salesforce/blip-image-captioning-large":
433
+ raw_image = load_image(data["img_url"]).convert('RGB')
434
+ text = data["text"]
435
+ inputs = pipes[model_id]["processor"](raw_image, return_tensors="pt").to(pipes[model_id]["device"])
436
+ out = pipe.generate(**inputs)
437
+ caption = pipes[model_id]["processor"].decode(out[0], skip_special_tokens=True)
438
+ result = {"generated text": caption}
439
+ if model_id == "ydshieh/vit-gpt2-coco-en":
440
+ img_url = data["img_url"]
441
+ generated_text = pipe(img_url)[0]['generated_text']
442
+ result = {"generated text": generated_text}
443
+ if model_id == "nlpconnect/vit-gpt2-image-captioning":
444
+ image = load_image(data["img_url"]).convert("RGB")
445
+ pixel_values = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").pixel_values
446
+ pixel_values = pixel_values.to(pipes[model_id]["device"])
447
+ generated_ids = pipe.generate(pixel_values, **{"max_length": 200, "num_beams": 1})
448
+ generated_text = pipes[model_id]["tokenizer"].batch_decode(generated_ids, skip_special_tokens=True)[0]
449
+ result = {"generated text": generated_text}
450
+ # image to text: OCR
451
+ if model_id == "microsoft/trocr-base-printed" or model_id == "microsoft/trocr-base-handwritten":
452
+ image = load_image(data["img_url"]).convert("RGB")
453
+ pixel_values = pipes[model_id]["processor"](image, return_tensors="pt").pixel_values
454
+ pixel_values = pixel_values.to(pipes[model_id]["device"])
455
+ generated_ids = pipe.generate(pixel_values)
456
+ generated_text = pipes[model_id]["processor"].batch_decode(generated_ids, skip_special_tokens=True)[0]
457
+ result = {"generated text": generated_text}
458
+
459
+ # text to image
460
+ if model_id == "runwayml/stable-diffusion-v1-5":
461
+ file_name = str(uuid.uuid4())[:4]
462
+ text = data["text"]
463
+ out = pipe(prompt=text)
464
+ out["images"][0].save(f"public/images/{file_name}.jpg")
465
+ result = {"path": f"/images/{file_name}.jpg"}
466
+
467
+ # object detection
468
+ if model_id == "google/owlvit-base-patch32" or model_id == "facebook/detr-resnet-101":
469
+ img_url = data["img_url"]
470
+ open_types = ["cat", "couch", "person", "car", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird"]
471
+ result = pipe(img_url, candidate_labels=open_types)
472
+
473
+ # VQA
474
+ if model_id == "dandelin/vilt-b32-finetuned-vqa":
475
+ question = data["text"]
476
+ img_url = data["img_url"]
477
+ result = pipe(question=question, image=img_url)
478
+
479
+ #DQA
480
+ if model_id == "impira/layoutlm-document-qa":
481
+ question = data["text"]
482
+ img_url = data["img_url"]
483
+ result = pipe(img_url, question)
484
+
485
+ # depth-estimation
486
+ if model_id == "Intel/dpt-large":
487
+ output = pipe(data["img_url"])
488
+ image = output['depth']
489
+ name = str(uuid.uuid4())[:4]
490
+ image.save(f"public/images/{name}.jpg")
491
+ result = {"path": f"/images/{name}.jpg"}
492
+
493
+ if model_id == "Intel/dpt-hybrid-midas" and model_id == "Intel/dpt-large":
494
+ image = load_image(data["img_url"])
495
+ inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt")
496
+ with torch.no_grad():
497
+ outputs = pipe(**inputs)
498
+ predicted_depth = outputs.predicted_depth
499
+ prediction = torch.nn.functional.interpolate(
500
+ predicted_depth.unsqueeze(1),
501
+ size=image.size[::-1],
502
+ mode="bicubic",
503
+ align_corners=False,
504
+ )
505
+ output = prediction.squeeze().cpu().numpy()
506
+ formatted = (output * 255 / np.max(output)).astype("uint8")
507
+ image = Image.fromarray(formatted)
508
+ name = str(uuid.uuid4())[:4]
509
+ image.save(f"public/images/{name}.jpg")
510
+ result = {"path": f"/images/{name}.jpg"}
511
+
512
+ # TTS
513
+ if model_id == "espnet/kan-bayashi_ljspeech_vits":
514
+ text = data["text"]
515
+ wav = pipe(text)["wav"]
516
+ name = str(uuid.uuid4())[:4]
517
+ sf.write(f"public/audios/{name}.wav", wav.cpu().numpy(), pipe.fs, "PCM_16")
518
+ result = {"path": f"/audios/{name}.wav"}
519
+
520
+ if model_id == "microsoft/speecht5_tts":
521
+ text = data["text"]
522
+ inputs = pipes[model_id]["processor"](text=text, return_tensors="pt")
523
+ embeddings_dataset = pipes[model_id]["embeddings_dataset"]
524
+ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(pipes[model_id]["device"])
525
+ pipes[model_id]["vocoder"].to(pipes[model_id]["device"])
526
+ speech = pipe.generate_speech(inputs["input_ids"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"])
527
+ name = str(uuid.uuid4())[:4]
528
+ sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
529
+ result = {"path": f"/audios/{name}.wav"}
530
+
531
+ # ASR
532
+ if model_id == "openai/whisper-base" or model_id == "microsoft/speecht5_asr":
533
+ audio_url = data["audio_url"]
534
+ result = { "text": pipe(audio_url)["text"]}
535
+
536
+ # audio to audio
537
+ if model_id == "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k":
538
+ audio_url = data["audio_url"]
539
+ wav, sr = torchaudio.load(audio_url)
540
+ with torch.no_grad():
541
+ result_wav = pipe(wav.to(pipes[model_id]["device"]))
542
+ name = str(uuid.uuid4())[:4]
543
+ sf.write(f"public/audios/{name}.wav", result_wav.cpu().squeeze().numpy(), sr)
544
+ result = {"path": f"/audios/{name}.wav"}
545
+
546
+ if model_id == "microsoft/speecht5_vc":
547
+ audio_url = data["audio_url"]
548
+ wav, sr = torchaudio.load(audio_url)
549
+ inputs = pipes[model_id]["processor"](audio=wav, sampling_rate=sr, return_tensors="pt")
550
+ embeddings_dataset = pipes[model_id]["embeddings_dataset"]
551
+ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
552
+ pipes[model_id]["vocoder"].to(pipes[model_id]["device"])
553
+ speech = pipe.generate_speech(inputs["input_ids"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"])
554
+ name = str(uuid.uuid4())[:4]
555
+ sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
556
+ result = {"path": f"/audios/{name}.wav"}
557
+
558
+ # segmentation
559
+ if model_id == "facebook/detr-resnet-50-panoptic":
560
+ result = []
561
+ segments = pipe(data["img_url"])
562
+ image = load_image(data["img_url"])
563
+
564
+ colors = []
565
+ for i in range(len(segments)):
566
+ colors.append((random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 50))
567
+
568
+ for segment in segments:
569
+ mask = segment["mask"]
570
+ mask = mask.convert('L')
571
+ layer = Image.new('RGBA', mask.size, colors[i])
572
+ image.paste(layer, (0, 0), mask)
573
+ name = str(uuid.uuid4())[:4]
574
+ image.save(f"public/images/{name}.jpg")
575
+ result = {"path": f"/images/{name}.jpg"}
576
+
577
+ if model_id == "facebook/maskformer-swin-base-coco" or model_id == "facebook/maskformer-swin-large-ade":
578
+ image = load_image(data["img_url"])
579
+ inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").to(pipes[model_id]["device"])
580
+ outputs = pipe(**inputs)
581
+ result = pipes[model_id]["feature_extractor"].post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
582
+ predicted_panoptic_map = result["segmentation"].cpu().numpy()
583
+ predicted_panoptic_map = Image.fromarray(predicted_panoptic_map.astype(np.uint8))
584
+ name = str(uuid.uuid4())[:4]
585
+ predicted_panoptic_map.save(f"public/images/{name}.jpg")
586
+ result = {"path": f"/images/{name}.jpg"}
587
+
588
+ except Exception as e:
589
+ print(e)
590
+ traceback.print_exc()
591
+ result = {"error": {"message": "Error when running the model inference."}}
592
+
593
+ if "device" in pipes[model_id]:
594
+ try:
595
+ pipe.to("cpu")
596
+ torch.cuda.empty_cache()
597
+ except:
598
+ pipe.device = torch.device("cpu")
599
+ pipe.model.to("cpu")
600
+ torch.cuda.empty_cache()
601
+
602
+ pipes[model_id]["using"] = False
603
+
604
+ if result is None:
605
+ result = {"error": {"message": "model not found"}}
606
+
607
+ end = time.time()
608
+ during = end - start
609
+ print(f"[ complete {model_id} ] {during}s")
610
+ print(f"[ result {model_id} ] {result}")
611
+
612
+ return result
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ tesseract-ocr
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/diffusers.git@8c530fc2f6a76a2aefb6b285dce6df1675092ac6#egg=diffusers
2
+ git+https://github.com/huggingface/transformers@c612628045822f909020f7eb6784c79700813eda#egg=transformers
3
+ git+https://github.com/patrickvonplaten/controlnet_aux@78efc716868a7f5669c288233d65b471f542ce40#egg=controlnet_aux
4
+ tiktoken==0.3.3
5
+ pydub==0.25.1
6
+ espnet==202301
7
+ espnet_model_zoo==0.1.7
8
+ flask==2.2.3
9
+ flask_cors==3.0.10
10
+ waitress==2.1.2
11
+ datasets==2.11.0
12
+ asteroid==0.6.0
13
+ speechbrain==0.5.14
14
+ timm==0.6.13
15
+ typeguard==2.13.3
16
+ accelerate==0.18.0
17
+ pytesseract==0.3.10
18
+ basicsr==1.4.2