Omnibus committed on
Commit 1fcc518 · verified · 1 Parent(s): 3dea939

Update app.py

Files changed (1): app.py (+120, -4)
app.py CHANGED
@@ -103,7 +103,7 @@ def compress_history(formatted_prompt):
     return output


-def question_generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1028, top_p=0.95, repetition_penalty=1.0,):
+def comment_generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1028, top_p=0.95, repetition_penalty=1.0,):
     #def question_generate(prompt, history):
     print("###############\nRUNNING QUESTION GENERATOR\n###############\n")
     seed = random.randint(1,1111111111111111)
@@ -136,7 +136,7 @@ def question_generate(prompt, history, agent_name=agents[0], sys_prompt="", temp

     return output

-def blog_poster_reply(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,):
+def reply_generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,):
     #def question_generate(prompt, history):
     print("###############\nRUNNING BLOG POSTER REPLY\n###############\n")
     seed = random.randint(1,1111111111111111)
@@ -221,6 +221,115 @@ def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0
     title=""
     filename=create_valid_filename(f'{current_time}---{title}')

+    current_time=current_time.replace(":","-")
+    current_time=current_time.replace(".","-")
+    print (current_time)
+    agent=prompts.BLOG_POSTER
+    system_prompt=agent
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+    hist_out=[]
+    sum_out=[]
+    json_hist={}
+    json_obj={}
+    full_conv=[]
+    post_cnt=1
+    seed = random.randint(1,1111111111111111)
+    if not history:
+        generate_kwargs = dict(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens2,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            do_sample=True,
+            seed=seed,
+        )
+        if prompt.startswith(' \"'):
+            prompt=prompt.strip(' \"')
+
+        formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
+
+        if len(formatted_prompt) < (40000):
+            print(len(formatted_prompt))
+
+            client=client_z[0]
+            stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+            output = ""
+            #if history:
+            #    yield history
+
+            for response in stream:
+                output += response.token.text
+                yield '', [(prompt,output)],summary[0],json_obj, json_hist,html_out
+
+            if not title:
+                for line in output.split("\n"):
+                    if "title" in line.lower() and ":" in line.lower():
+                        title = line.split(":")[1]
+                print(f'title:: {title}')
+                filename=create_valid_filename(f'{current_time}---{title}')
+
+            out_json = {"prompt":prompt,"output":output}
+
+            hist_out.append(out_json)
+            #try:
+            #    for ea in
+            with open(f'{uid}.json', 'w') as f:
+                json_hist=json.dumps(hist_out, indent=4)
+                f.write(json_hist)
+            f.close()
+
+            upload_file(
+                path_or_fileobj =f"{uid}.json",
+                path_in_repo = f"book1/{filename}.json",
+                repo_id =f"{username}/{dataset_name}",
+                repo_type = "dataset",
+                token=token,
+            )
+        else:
+            formatted_prompt = format_prompt(f"{prompts.COMPRESS_HISTORY_PROMPT.format(history=summary[0],focus=main_point[0])}, {summary[0]}", history)
+
+            #current_time = str(datetime.datetime.now().timestamp()).split(".",1)[0]
+            #filename=f'{filename}-{current_time}'
+            history = []
+            output = compress_history(formatted_prompt)
+            summary[0]=output
+            sum_json = {"summary":summary[0]}
+            sum_out.append(sum_json)
+            with open(f'{uid}-sum.json', 'w') as f:
+                json_obj=json.dumps(sum_out, indent=4)
+                f.write(json_obj)
+            f.close()
+            upload_file(
+                path_or_fileobj =f"{uid}-sum.json",
+                path_in_repo = f"book1/{filename}-summary.json",
+                repo_id =f"{username}/{dataset_name}",
+                repo_type = "dataset",
+                token=token,
+            )
+
+
+        #prompt = question_generate(output, history)
+        #main_point[0]=prompt
+        #full_conv.append((output,prompt))
+
+
+        html_out=load_html(full_conv,title)
+        yield prompt, history, summary[0],json_obj,json_hist,html_out
+    else:
+        pass
+
+def generate_OG(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0):
+    html_out=""
+    #main_point[0]=prompt
+    #print(datetime.datetime.now())
+    uid=uuid.uuid4()
+    current_time = str(datetime.datetime.now())
+    title=""
+    filename=create_valid_filename(f'{current_time}---{title}')
+
     current_time=current_time.replace(":","-")
     current_time=current_time.replace(".","-")
     print (current_time)
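The new blog-post body streams tokens from the inference client, then persists each exchange to a Hugging Face dataset repo as JSON. A minimal sketch of that pattern, assuming huggingface_hub's `InferenceClient`; the model id, repo id, token, and the `stream_and_save` helper are placeholders, not the app's actual configuration (the app reads these from `client_z`, `username`, `dataset_name`, and `token`):

```python
import json
import uuid

from huggingface_hub import InferenceClient, upload_file

# Placeholders -- the real app pulls the client from client_z[0]
# and the repo/token from its own config.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
REPO_ID = "user/blog-posts"   # hypothetical dataset repo
HF_TOKEN = "hf_..."           # hypothetical write token

def stream_and_save(prompt: str) -> str:
    uid = uuid.uuid4()
    output = ""
    # With stream=True and details=True, each chunk exposes the new
    # token as response.token.text, which the loop in the diff
    # accumulates into `output` while yielding partial results.
    for response in client.text_generation(
        prompt,
        max_new_tokens=256,
        stream=True,
        details=True,
        return_full_text=False,
    ):
        output += response.token.text

    # Persist the exchange as JSON, then push it into the dataset
    # repo, mirroring the book1/{filename}.json upload above.
    path = f"{uid}.json"
    with open(path, "w") as f:
        json.dump([{"prompt": prompt, "output": output}], f, indent=4)
    upload_file(
        path_or_fileobj=path,
        path_in_repo=f"book1/{uid}.json",
        repo_id=REPO_ID,
        repo_type="dataset",
        token=HF_TOKEN,
    )
    return output
```

The `else:` branch follows the same save-and-upload shape, but first compresses the running conversation with `compress_history` so the summary, rather than the raw history, is what gets stored.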
@@ -354,9 +463,14 @@ with gr.Blocks() as app:
     chatbot=gr.Chatbot()
     msg = gr.Textbox()
     with gr.Row():
-        submit_b = gr.Button()
+
+        submit_b = gr.Button("Blog Post")
+        submit_c = gr.Button("Comment")
+        submit_r = gr.Button("OP Reply")
+    with gr.Row():
         stop_b = gr.Button("Stop")
         clear = gr.ClearButton([msg, chatbot])
+
     with gr.Row():
         m_choice=gr.Dropdown(label="Models",type='index',choices=[c for c in models],value=models[0],interactive=True)
         tokens = gr.Slider(label="Max new tokens",value=1600,minimum=0,maximum=8000,step=64,interactive=True, visible=True,info="The maximum number of tokens")
@@ -371,7 +485,9 @@ with gr.Blocks() as app:
     app.load(load_models,m_choice,[chatbot]).then(load_html,None,html)

     sub_b = submit_b.click(generate, [msg,chatbot,tokens],[msg,chatbot,sumbox,sum_out_box,hist_out_box,html])
+    sub_c = submit_c.click(comment_generate, [msg,chatbot,tokens],[msg,chatbot,sumbox,sum_out_box,hist_out_box,html])
+    sub_r = submit_r.click(reply_generate, [msg,chatbot,tokens],[msg,chatbot,sumbox,sum_out_box,hist_out_box,html])
     sub_e = msg.submit(generate, [msg, chatbot,tokens], [msg, chatbot,sumbox,sum_out_box,hist_out_box,html])
-    stop_b.click(None,None,None, cancels=[sub_b,sub_e])
+    stop_b.click(None,None,None, cancels=[sub_b,sub_e,sub_c,sub_r])

     app.queue(default_concurrency_limit=20).launch()
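The UI changes give each generator its own button and extend the Stop button's `cancels` list so whichever stream is running can be interrupted. A self-contained sketch of that wiring, with a hypothetical `fake_generate` handler standing in for the app's `generate`, `comment_generate`, and `reply_generate` (which take more inputs and outputs):

```python
import time

import gradio as gr

# Stand-in generator handler; the real ones stream model output.
def fake_generate(message, history):
    history = history + [(message, "")]
    for ch in "streamed reply":
        history[-1] = (message, history[-1][1] + ch)
        time.sleep(0.05)
        yield "", history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    with gr.Row():
        post_b = gr.Button("Blog Post")
        comment_b = gr.Button("Comment")
        reply_b = gr.Button("OP Reply")
    with gr.Row():
        stop_b = gr.Button("Stop")
        gr.ClearButton([msg, chatbot])

    # One event handle per trigger; collecting them lets the Stop
    # button cancel whichever event is in flight.
    events = [
        post_b.click(fake_generate, [msg, chatbot], [msg, chatbot]),
        comment_b.click(fake_generate, [msg, chatbot], [msg, chatbot]),
        reply_b.click(fake_generate, [msg, chatbot], [msg, chatbot]),
        msg.submit(fake_generate, [msg, chatbot], [msg, chatbot]),
    ]
    stop_b.click(None, None, None, cancels=events)

demo.queue(default_concurrency_limit=20).launch()
```

`cancels=` only interrupts queued events, which is why the `app.queue(...)` call before `launch()` matters for the Stop button.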