Update app.py
app.py CHANGED
@@ -7,13 +7,11 @@ import json
 import uuid
 import os
 
-
-
 token=os.environ.get("HF_TOKEN")
 username="omnibus"
 dataset_name="tmp"
 api=HfApi(token="")
-
+VERBOSE=False
 
 history = []
 hist_out= []
@@ -22,6 +20,30 @@ main_point=[]
 summary.append("")
 main_point.append("")
 
+models=[
+    "google/gemma-7b",
+    "google/gemma-7b-it",
+    "google/gemma-2b",
+    "google/gemma-2b-it",
+    "meta-llama/Llama-2-7b-chat-hf",
+    "codellama/CodeLlama-70b-Instruct-hf",
+    "openchat/openchat-3.5-0106",
+    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
+    "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "mistralai/Mixtral-8x7B-Instruct-v0.2",
+]
+
+client_z=[]
+
+def load_models(inp):
+    if VERBOSE==True:
+        print(type(inp))
+        print(inp)
+        print(models[inp])
+    client_z.clear()
+    client_z.append(InferenceClient(models[inp]))
+    return gr.update(label=models[inp])
+
 def format_prompt(message, history):
     prompt = "<s>"
     for user_prompt, bot_response in history:
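For context, the new load_models helper is driven by a Gradio dropdown declared later in the diff with type='index', so its change event hands the function the integer position of the selected model. A minimal sketch of that flow, with a shortened model list and no UI wiring; only models, client_z and load_models mirror the diff, everything else is illustrative:

# Sketch: with type='index' the Dropdown passes the integer position of the
# selection, so load_models(inp) can index straight into the models list.
import gradio as gr
from huggingface_hub import InferenceClient

models = ["google/gemma-7b-it", "mistralai/Mixtral-8x7B-Instruct-v0.1"]  # shortened
client_z = []   # single-slot holder; the generation code later reads client_z[0]

def load_models(inp):
    client_z.clear()
    client_z.append(InferenceClient(models[inp]))
    return gr.update(label=models[inp])   # relabels the bound output component

load_models(1)              # as if the second dropdown entry had been picked
print(type(client_z[0]))    # the shared InferenceClient for that model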
@@ -64,6 +86,7 @@ def compress_history(formatted_prompt):
     #history.append((prompt,""))
     #formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
     formatted_prompt = formatted_prompt
+    client=client_z[0]
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
 
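Each generation path in this commit now pulls the shared client out of client_z immediately before streaming. A small sketch of that pattern in isolation, assuming a reachable Inference API model and using a placeholder prompt and kwargs in place of the app's generate_kwargs:

# Sketch of the streaming pattern used above; prompt and kwargs are placeholders.
from huggingface_hub import InferenceClient

client_z = [InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")]
generate_kwargs = dict(temperature=0.9, max_new_tokens=256, top_p=0.95,
                       repetition_penalty=1.0, do_sample=True)

client = client_z[0]   # whichever client load_models last put there
stream = client.text_generation("<s>[INST] Hello [/INST]", **generate_kwargs,
                                stream=True, details=True, return_full_text=False)
output = ""
for response in stream:
    output += response.token.text   # each streamed chunk carries one token
print(output)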
@@ -96,7 +119,10 @@ def question_generate(prompt, history, agent_name=agents[0], sys_prompt="", temp
     )
     #history.append((prompt,""))
     formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
+    client=client_z[0]
+
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+
     output = ""
 
     for response in stream:
@@ -126,6 +152,7 @@ def blog_poster_reply(prompt, history, agent_name=agents[0], sys_prompt="", temp
     )
     #history.append((prompt,""))
     formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
+    client=client_z[0]
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
 
@@ -180,7 +207,7 @@ def load_html(inp,title):
 
 
 
-def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0,):
+def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0, m_choice):
     html_out=""
     #main_point[0]=prompt
     #print(datetime.datetime.now())
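One caveat on the new signature: in Python a parameter without a default cannot follow defaulted parameters, so m_choice would need a default (or an earlier position) for this line to parse. A sketch of a variant that parses, with placeholder defaults standing in for the app's values:

# Sketch only: m_choice given an assumed default of 0 (first dropdown entry).
def generate(prompt, history, agent_name=None, sys_prompt="",
             temperature=0.9, max_new_tokens=1048, top_p=0.95,
             repetition_penalty=1.0, m_choice=0):
    ...  # body unchanged from the diff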
@@ -241,7 +268,7 @@ def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0
     if len(formatted_prompt) < (40000):
         print(len(formatted_prompt))
 
-
+        client=client_z[0]
         stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
         output = ""
         #if history:
@@ -325,16 +352,19 @@ with gr.Blocks() as app:
         submit_b = gr.Button()
         stop_b = gr.Button("Stop")
         clear = gr.ClearButton([msg, chatbot])
+        m_choice=gr.Dropdown(label="Models",type='index',choices=[c for c in models],value=models[0],interactive=True)
+
         sumbox=gr.Textbox("Summary", max_lines=100)
     with gr.Column():
         sum_out_box=gr.JSON(label="Summaries")
         hist_out_box=gr.JSON(label="History")
-
-
-
-
-
+
+
+    client_choice.change(load_models,client_choice,[chat_b])
+    app.load(load_models,client_choice,[chat_b]).then(load_html,None,html)
 
+    sub_b = submit_b.click(generate, [msg,chatbot],[msg,chatbot,sumbox,sum_out_box,hist_out_box,html,m_choice])
+    sub_e = msg.submit(generate, [msg, chatbot], [msg, chatbot,sumbox,sum_out_box,hist_out_box,html,m_choice])
+    stop_b.click(None,None,None, cancels=[sub_b,sub_e])
 
-    app.load(load_html,None,html)
 app.queue(default_concurrency_limit=20).launch()
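The UI hunk above adds the model dropdown and reroutes startup through load_models before the existing load_html call; it also references client_choice and chat_b, which are defined elsewhere in the file and not shown in this diff. A compact sketch of the intended wiring under assumed component names, with a single dropdown used throughout and an echo stub standing in for the app's streaming generate:

# Sketch of the event wiring (assumed names; the real app defines more components).
import gradio as gr
from huggingface_hub import InferenceClient

models = ["google/gemma-7b-it", "mistralai/Mixtral-8x7B-Instruct-v0.1"]
client_z = []

def load_models(inp):
    client_z.clear()
    client_z.append(InferenceClient(models[inp]))
    return gr.update(label=models[inp])

def generate(prompt, history):
    # placeholder for the app's streaming generate()
    history = (history or []) + [(prompt, f"(echo) {prompt}")]
    return "", history

with gr.Blocks() as app:
    chat_b = gr.Chatbot()
    msg = gr.Textbox()
    submit_b = gr.Button()
    stop_b = gr.Button("Stop")
    m_choice = gr.Dropdown(label="Models", type="index", choices=models,
                           value=models[0], interactive=True)

    # switching models rebuilds the shared client and relabels the chatbot;
    # on page load the default model is loaded the same way
    m_choice.change(load_models, m_choice, [chat_b])
    app.load(load_models, m_choice, [chat_b])

    sub_b = submit_b.click(generate, [msg, chat_b], [msg, chat_b])
    sub_e = msg.submit(generate, [msg, chat_b], [msg, chat_b])
    stop_b.click(None, None, None, cancels=[sub_b, sub_e])

app.queue(default_concurrency_limit=20).launch()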