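# Gradio demo for assisted generation: a large Pythia checkpoint generates text with (or without) the help
# of a much smaller "assistant" checkpoint, and generated tokens are streamed to the UI as they are produced.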
from threading import Thread
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import time

model_id = "EleutherAI/pythia-6.9b-deduped"
assistant_id = "EleutherAI/pythia-70m-deduped"

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())
if torch_device == "cuda":
    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
else:
    model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id).to(torch_device)
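
# For reference, assisted generation can also be called directly, without the UI or streaming. A minimal
# sketch reusing the objects created above (the prompt text is just an example):
#
#   inputs = tokenizer("A sequence: one, two, three, ", return_tensors="pt").to(torch_device)
#   output_ids = model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=20)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
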
def run_generation(user_text, use_assistant, temperature, max_new_tokens):
    # Treat very low temperatures as greedy decoding; otherwise, sample.
    if temperature < 0.1:
        do_sample = False
    else:
        do_sample = True

    # Tokenize the user text.
    model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)

    # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
    # in the main thread. A timeout on the streamer handles exceptions raised in the generation thread.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        assistant_model=assistant_model if use_assistant else None,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=0.95,
        temperature=float(temperature),
        top_k=50,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    start = time.time()
    t.start()

    # Pull the generated text from the streamer and accumulate it into the model output, yielding the text
    # generated so far and the elapsed time after each new chunk.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield [model_output, round(time.time() - start, 3)]
    return [model_output, round(time.time() - start, 3)]
# Helper to clear the prompt textbox (not wired to an event in this demo).
def reset_textbox():
    return gr.update(value='')
with gr.Blocks() as demo:
    gr.Markdown(
        "# 🤗 Assisted Generation Demo\n"
        f"- Model: {model_id} (using 4-bit quantization)\n"
        f"- Assistant Model: {assistant_id}\n"
        "- Disclaimer: due to 4-bit quantization and the use of causal masking in assisted generation, the output "
        "of greedy decoding may differ on rare occasions."
    )
    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder="A sequence: one, two, three, ",
                label="Prompt"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")

        with gr.Column(scale=1, min_width=200):
            gr.Markdown("### Generation Settings")
            use_assistant = gr.Checkbox(label="Use Assisted Generation", value=True)
            max_new_tokens = gr.Slider(
                minimum=1, maximum=500, value=100, step=1, interactive=True, label="Max New Tokens",
            )
            temperature = gr.Slider(
                minimum=0.0, maximum=2.0, value=0.0, step=0.1, interactive=True, label="Temperature (0.0 = Greedy)",
            )
            gr.Markdown("### Generation time (seconds)")
            generation_time = gr.Textbox(lines=1, interactive=False, show_label=False)
    # Wire both the textbox submit action and the button click to the streaming generation function.
    generate_inputs = [user_text, use_assistant, temperature, max_new_tokens]
    generate_outputs = [model_output, generation_time]
    user_text.submit(run_generation, generate_inputs, generate_outputs)
    button_submit.click(run_generation, generate_inputs, generate_outputs)

demo.queue(max_size=32).launch()
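# Note: this script is laid out as a Hugging Face Spaces app. Assuming it is saved as app.py (the standard
# Spaces entry point), it can also be run locally with `python app.py`, provided torch, transformers, gradio,
# and (for 4-bit loading on GPU) bitsandbytes are installed.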