Chris4K committed
Commit a0dfe96
1 Parent(s): 2bca17a

Update app.py

Files changed (1):
  1. app.py +80 -1
app.py CHANGED
@@ -243,6 +243,84 @@ app = gr.Interface(
     inputs=["text", "checkbox", gr.Slider(0, 100)],
     outputs=["text", "number"],
 )
+
+####
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextIteratorStreamer
+
+model_id = "philschmid/instruct-igel-001"
+model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+prompt_template = f"### Anweisung:\n{{input}}\n\n### Antwort:"
+
+def generate(instruction, temperature=1.0, max_new_tokens=256, top_p=0.9, length_penalty=1.0):
+    formatted_instruction = prompt_template.format(input=instruction)
+
+    # make sure temperature, top_p and length_penalty are floats
+    temperature = float(temperature)
+    top_p = float(top_p)
+    length_penalty = float(length_penalty)
+
+    # COMMENT IN FOR NON-STREAMING
+    # generation_config = GenerationConfig(
+    #     do_sample=True,
+    #     top_p=top_p,
+    #     top_k=0,
+    #     temperature=temperature,
+    #     max_new_tokens=max_new_tokens,
+    #     early_stopping=True,
+    #     length_penalty=length_penalty,
+    #     eos_token_id=tokenizer.eos_token_id,
+    #     pad_token_id=tokenizer.pad_token_id,
+    # )
+    #
+    # input_ids = tokenizer(
+    #     formatted_instruction, return_tensors="pt", truncation=True, max_length=2048
+    # ).input_ids.cuda()
+    #
+    # with torch.inference_mode(), torch.autocast("cuda"):
+    #     outputs = model.generate(input_ids=input_ids, generation_config=generation_config)[0]
+    #
+    # output = tokenizer.decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)
+    # return output.split("### Antwort:\n")[1]
+
+    # STREAMING, BASED ON git+https://github.com/gante/transformers.git@streamer_iterator
+    streamer = TextIteratorStreamer(tokenizer)
+    model_inputs = tokenizer(formatted_instruction, return_tensors="pt", truncation=True, max_length=2048)
+    # move inputs to the GPU (`device` must be defined elsewhere in app.py)
+    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
+
+    generate_kwargs = dict(
+        top_p=top_p,
+        top_k=0,
+        temperature=temperature,
+        do_sample=True,
+        max_new_tokens=max_new_tokens,
+        early_stopping=True,
+        length_penalty=length_penalty,
+        eos_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer.eos_token_id,
+    )
+    t = Thread(target=model.generate, kwargs={**dict(model_inputs, streamer=streamer), **generate_kwargs})
+    t.start()
+
+    output = ""
+    hidden_output = ""
+    for new_text in streamer:
+        # skip streaming until new text is available (the streamer echoes the prompt first)
+        if len(hidden_output) <= len(formatted_instruction):
+            hidden_output += new_text
+            continue
+        # strip the eos token from the visible output
+        if tokenizer.eos_token in new_text:
+            new_text = new_text.replace(tokenizer.eos_token, "")
+        output += new_text
+        yield output
+    # if HF_TOKEN:
+    #     save_inputs_and_outputs(formatted_instruction, output, generate_kwargs)
+    return output
+
 #app.launch()
 ####################
 
@@ -270,11 +348,12 @@ def topic_sale_inform (text):
 #conversation = Conversation("Welcome")
 
 def callChains(current_message):
+    final_answer = generate(current_message, 1.0, 256, 0.9, 1.0)
     sentiment_analysis_result = pipeline_predict_sentiment(current_message)
     topic_sale_inform_result = topic_sale_inform(current_message)
     #conversation.append_response("The Big lebowski.")
     #conversation.add_user_input("Is it good?")
-    final_answer = func(current_message)
+    #final_answer = func(current_message)
     return final_answer, sentiment_analysis_result, topic_sale_inform_result
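One caveat with the second hunk: because the new `generate` contains `yield`, calling it returns a generator object, so `final_answer` here is a generator rather than a finished string unless the caller drains it. A hedged sketch of one way the caller could do that, assuming the Gradio "text" output expects a plain string (`pipeline_predict_sentiment` and `topic_sale_inform` are the functions already defined in app.py):

```python
# Hypothetical consumer: exhaust the generator to obtain the final answer.
# (Alternatively, callChains could itself `yield` partial tuples so the
# Gradio output streams incrementally.)
def callChains(current_message):
    final_answer = ""
    for partial in generate(current_message, 1.0, 256, 0.9, 1.0):
        final_answer = partial  # each item is the cumulative output so far
    sentiment_analysis_result = pipeline_predict_sentiment(current_message)
    topic_sale_inform_result = topic_sale_inform(current_message)
    return final_answer, sentiment_analysis_result, topic_sale_inform_result
```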