Update app.py
app.py CHANGED
@@ -243,6 +243,84 @@ app = gr.Interface(
     inputs=["text", "checkbox", gr.Slider(0, 100)],
     outputs=["text", "number"],
 )
+
+####
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextIteratorStreamer
+
+model_id = "philschmid/instruct-igel-001"
+model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+prompt_template = f"### Anweisung:\n{{input}}\n\n### Antwort:"
+
+def generate(instruction, temperature=1.0, max_new_tokens=256, top_p=0.9, length_penalty=1.0):
+    formatted_instruction = prompt_template.format(input=instruction)
+
+    # make sure temperature top_p and length_penalty are floats
+    temperature = float(temperature)
+    top_p = float(top_p)
+    length_penalty = float(length_penalty)
+
+    # COMMENT IN FOR NON STREAMING
+    # generation_config = GenerationConfig(
+    #     do_sample=True,
+    #     top_p=top_p,
+    #     top_k=0,
+    #     temperature=temperature,
+    #     max_new_tokens=max_new_tokens,
+    #     early_stopping=True,
+    #     length_penalty=length_penalty,
+    #     eos_token_id=tokenizer.eos_token_id,
+    #     pad_token_id=tokenizer.pad_token_id,
+    # )
+
+    # input_ids = tokenizer(
+    #     formatted_instruction, return_tensors="pt", truncation=True, max_length=2048
+    # ).input_ids.cuda()
+
+    # with torch.inference_mode(), torch.autocast("cuda"):
+    #     outputs = model.generate(input_ids=input_ids, generation_config=generation_config)[0]
+
+    # output = tokenizer.decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)
+    # return output.split("### Antwort:\n")[1]
+
+    # STREAMING BASED ON git+https://github.com/gante/transformers.git@streamer_iterator
+
+    # streaming
+    streamer = TextIteratorStreamer(tokenizer)
+    model_inputs = tokenizer(formatted_instruction, return_tensors="pt", truncation=True, max_length=2048)
+    # move to gpu
+    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
+
+    generate_kwargs = dict(
+        top_p=top_p,
+        top_k=0,
+        temperature=temperature,
+        do_sample=True,
+        max_new_tokens=max_new_tokens,
+        early_stopping=True,
+        length_penalty=length_penalty,
+        eos_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer.eos_token_id,
+    )
+    t = Thread(target=model.generate, kwargs={**dict(model_inputs, streamer=streamer), **generate_kwargs})
+    t.start()
+
+    output = ""
+    hidden_output = ""
+    for new_text in streamer:
+        # skip streaming until new text is available
+        if len(hidden_output) <= len(formatted_instruction):
+            hidden_output += new_text
+            continue
+        # replace eos token
+        if tokenizer.eos_token in new_text:
+            new_text = new_text.replace(tokenizer.eos_token, "")
+        output += new_text
+        yield output
+    # if HF_TOKEN:
+    #     save_inputs_and_outputs(formatted_instruction, output, generate_kwargs)
+    return output
+
 #app.launch()
 ####################
 
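Note: the added block calls Thread(...) and references device, but neither is imported or defined in this hunk, and the commented-out non-streaming path would also need torch. A minimal sketch of the setup the code appears to assume elsewhere in app.py (the exact names and placement are assumptions, not shown in this diff):

import torch
from threading import Thread

# presumed device selection: the added code moves model_inputs to `device`
device = "cuda" if torch.cuda.is_available() else "cpu"
# the model itself presumably has to sit on the same device as its inputs
model = model.to(device)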
@@ -270,11 +348,12 @@ def topic_sale_inform (text):
 #conversation = Conversation("Welcome")
 
 def callChains(current_message):
+    final_answer = generate(current_message, 1.0, 256, 0.9, 1.0)
     sentiment_analysis_result = pipeline_predict_sentiment(current_message)
     topic_sale_inform_result = topic_sale_inform(current_message)
     #conversation.append_response("The Big lebowski.")
     #conversation.add_user_input("Is it good?")
-    final_answer = func(current_message)
+    #final_answer = func(current_message)
     return final_answer, sentiment_analysis_result, topic_sale_inform_result
 
 
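Note: because generate contains yield, calling it returns a generator object, so final_answer here is a generator rather than a finished string. If callChains is meant to return plain text, it would have to drain the stream; a hedged sketch (not the committed code) that keeps the last cumulative chunk:

def callChains(current_message):
    # each value yielded by generate() is the full output so far,
    # so the last one is the complete answer
    final_answer = ""
    for partial in generate(current_message, 1.0, 256, 0.9, 1.0):
        final_answer = partial
    sentiment_analysis_result = pipeline_predict_sentiment(current_message)
    topic_sale_inform_result = topic_sale_inform(current_message)
    return final_answer, sentiment_analysis_result, topic_sale_inform_result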
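For completeness, a hypothetical smoke test of the streaming path (not part of the commit; the German prompt is an arbitrary example for this German-language instruct model):

if __name__ == "__main__":
    # consume the stream; each iteration yields the cumulative text so far
    for partial in generate("Schreibe ein kurzes Gedicht über Berlin."):
        print(partial)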