Spaces:
Paused
Paused
Commit
·
9a81d74
1
Parent(s):
1abd311
adding text streaming
Browse files
app.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
|
| 5 |
|
| 6 |
token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
|
| 7 |
|
|
@@ -51,15 +51,17 @@ def get_prompt_with_template(message: str) -> str:
|
|
| 51 |
def generate_model_response(message: str) -> str:
    """Generate the model's reply for *message* and return the decoded text.

    The user message is wrapped in the chat prompt template, tokenized, and
    fed to ``model.generate``; the full decoded output (prompt + completion)
    is returned, with special tokens stripped.
    """
    templated_prompt = get_prompt_with_template(message)
    model_inputs = tokenizer(templated_prompt, return_tensors='pt')
    # Move the input tensors onto the GPU when one is available.
    if torch.cuda.is_available():
        model_inputs = model_inputs.to('cuda')
    # Include **generate_kwargs to include the user-defined options
    generated = model.generate(
        **model_inputs,
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.1,
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return decoded
|
| 63 |
|
| 64 |
def extract_response_content(full_response: str) -> str:
|
| 65 |
response_start_index = full_response.find("### Assistant:")
|
|
|
|
| 1 |
import os
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline, TextStreamer
|
| 5 |
|
| 6 |
token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
|
| 7 |
|
|
|
|
| 51 |
def generate_model_response(message: str) -> str:
    """Generate the model's reply for *message*, streaming tokens as they come.

    A ``TextStreamer`` prints each decoded token to stdout while generation
    runs; the complete decoded output (prompt + completion, special tokens
    stripped) is still returned so downstream callers such as
    ``extract_response_content`` receive a string.

    Note: the commit this page shows commented out the ``return``, which made
    the function return ``None`` despite its ``-> str`` annotation and broke
    ``full_response.find(...)`` in ``extract_response_content`` — restored here.
    """
    prompt = get_prompt_with_template(message)
    inputs = tokenizer(prompt, return_tensors='pt')
    # Streams decoded tokens to stdout as they are generated.
    streamer = TextStreamer(tokenizer)
    if torch.cuda.is_available():
        inputs = inputs.to('cuda')
    # Include **generate_kwargs to include the user-defined options
    output = model.generate(**inputs,
                            max_new_tokens=4096,
                            do_sample=True,
                            temperature=0.1,
                            streamer=streamer
                            )
    # Return the full decoded text in addition to streaming it, so the
    # declared str return type (and callers that parse it) keep working.
    return tokenizer.decode(output[0], skip_special_tokens=True)
|
| 65 |
|
| 66 |
def extract_response_content(full_response: str) -> str:
|
| 67 |
response_start_index = full_response.find("### Assistant:")
|