adowu's picture
Update app.py
0477f96 verified
import datetime
import os
import random
import re
from io import StringIO
import gradio as gr
import pandas as pd
from huggingface_hub import upload_file
from text_generation import Client
# Optional credentials read from the environment; each is None when unset.
# HF_TOKEN gates and authorizes uploads of prompts/completions to the Hub,
# API_TOKEN is sent as the Bearer token to the inference endpoints.
HF_TOKEN = os.environ.get("HF_TOKEN")
API_TOKEN = os.environ.get("API_TOKEN")

# Hub dataset repository that receives logged dialogues when saving is enabled.
DIALOGUES_DATASET = "HuggingFaceH4/starchat_playground_dialogues"

# Display name (shown in the model picker) -> inference endpoint URL.
model2endpoint = {
    "starchat-alpha": "https://api-inference.huggingface.co/models/HuggingFaceH4/starcoderbase-finetuned-oasst1",
    "starchat-beta": "https://api-inference.huggingface.co/models/HuggingFaceH4/starchat-beta",
}
# Selectable model names, in the dict's insertion order.
model_names = [*model2endpoint]
def randomize_seed_generator():
    """Return a fresh sampling seed drawn uniformly from [0, 1_000_000] (inclusive)."""
    return random.randint(0, 1_000_000)
def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs, model):
    """Upload one prompt/completion record to ``DIALOGUES_DATASET`` on the Hub.

    The record is serialized as a single JSON Lines row and stored at
    ``{now.date()}/{now.hour}/prompts_{timestamp}.jsonl``.

    Args:
        now: ``datetime`` used for both the folder path and the file timestamp.
        inputs: the prompt that was sent to the model.
        outputs: the generated completion.
        generate_kwargs: sampling parameters used for the generation.
        model: name of the model that produced ``outputs``.

    Raises:
        Propagates any exception from serialization or ``upload_file``; the
        caller decides whether saving is best-effort.
    """
    # Use the caller's ``now`` for the file name as well as the folder, so a call
    # near an hour/day boundary cannot store a file whose timestamp disagrees
    # with the folder it lives in.
    timestamp = now.strftime("%Y-%m-%dT%H:%M:%S.%f")
    file_name = f"prompts_{timestamp}.jsonl"
    record = {"model": model, "inputs": inputs, "outputs": outputs, "generate_kwargs": generate_kwargs}
    # ``with`` releases the buffer even if serialization fails.
    with StringIO() as buffer:
        pd.DataFrame([record]).to_json(buffer, orient="records", lines=True)
        payload = buffer.getvalue().encode()
    # Push to Hub
    upload_file(
        path_in_repo=f"{now.date()}/{now.hour}/{file_name}",
        path_or_fileobj=payload,
        repo_id=DIALOGUES_DATASET,
        token=HF_TOKEN,
        repo_type="dataset",
    )
def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Flatten a chat transcript plus a new user turn into one prompt string.

    Every user turn is prefixed with ``user_name`` and every assistant turn with
    ``sep + assistant_name`` unless it already starts with that prefix. Assistant
    turns are right-stripped and each exchange ends with ``sep``. The result
    begins with ``preprompt`` and ends with ``sep`` plus the right-stripped
    ``assistant_name``, leaving an open assistant turn.
    """
    assistant_prefix = sep + assistant_name

    def _with_prefix(text, prefix):
        # Add ``prefix`` only when the text does not already carry it.
        return text if text.startswith(prefix) else prefix + text

    transcript = "".join(
        _with_prefix(user_turn, user_name) + _with_prefix(bot_turn, assistant_prefix).rstrip() + sep
        for user_turn, bot_turn in chatbot
    )
    return preprompt + transcript + _with_prefix(inputs, user_name) + sep + assistant_name.rstrip()
def wrap_html_code(text):
    """Wrap ``text`` in triple backticks if it contains anything tag-like.

    A "tag" is any ``<...>`` on a single line (the shortest match of ``<.*?>``).
    Fencing such text stops the Chatbot from rendering it as HTML. Text without
    a tag is returned unchanged.
    """
    looks_like_html = re.search(r"<.*?>", text) is not None
    if not looks_like_html:
        return text
    return f"```{text}```"
def has_no_history(chatbot, history):
    """Return True when both the chatbot pairs and the raw history are empty/falsy."""
    return not (chatbot or history)
def _build_prompt(system_message, past_messages, user_message):
    """Render the dialogue as a StarChat-style prompt ending with an open assistant turn."""
    # NOTE(review): assumes both endpoints were tuned on the
    # <|system|>/<|user|>/<|assistant|> ... <|end|> template (consistent with the
    # "<|end|>" stop sequence used in ``generate``) — confirm against the model cards.
    role_tokens = {"user": "<|user|>", "assistant": "<|assistant|>"}
    parts = [f"<|system|>\n{system_message}<|end|>\n"]
    for message in past_messages + [{"role": "user", "content": user_message}]:
        parts.append(f"{role_tokens[message['role']]}\n{message['content']}<|end|>\n")
    parts.append("<|assistant|>")
    return "".join(parts)


def _history_to_chat(history):
    """Pair the flat history into (user, assistant) tuples for the Chatbot, fencing HTML-like text."""
    return [
        (wrap_html_code(history[i].strip()), wrap_html_code(history[i + 1].strip()))
        for i in range(0, len(history) - 1, 2)
    ]


def generate(
    RETRY_FLAG,
    model_name,
    system_message,
    user_message,
    chatbot,
    history,
    temperature,
    top_k,
    top_p,
    max_new_tokens,
    repetition_penalty,
    do_save=True,
):
    """Stream a model reply to ``user_message`` and keep the chat state updated.

    Args:
        RETRY_FLAG: truthy when regenerating. ``user_message`` must then already be
            the last entry of ``history``, and a random seed is used instead of 42.
        model_name: key into ``model2endpoint``.
        system_message: system prompt placed at the top of the dialogue.
        user_message: the user turn to answer. If empty, the state is yielded
            unchanged once and the model is not called.
        chatbot: previous ``(user, assistant)`` pairs shown in the Chatbot.
        history: flat list ``[user, assistant, user, assistant, ...]``, mutated in place.
        temperature, top_k, top_p, max_new_tokens, repetition_penalty: sampling
            settings from the UI. ``temperature`` is floored at 0.01. A
            non-positive ``top_k`` disables top-k filtering.
        do_save: when true and ``HF_TOKEN`` is set, the prompt and completion are
            uploaded (best effort) via ``save_inputs_and_outputs``.

    Yields:
        ``(chat, history, user_message, "")``: the Chatbot pairs, the raw history,
        the value for ``last_user_message``, and "" to clear the input textbox.
        Yielded once per visible streamed token, or once if nothing visible was
        generated.
    """
    # Don't send a meaningless empty turn to the model; leave the UI unchanged.
    if not user_message:
        print("Empty input")
        yield chatbot, history, user_message, ""
        return

    if RETRY_FLAG:
        # Regenerations should differ from the previous answer.
        seed = randomize_seed_generator()
    else:
        history.append(user_message)
        seed = 42

    past_messages = []
    for user_data, model_data in chatbot:
        past_messages.append({"role": "user", "content": user_data})
        past_messages.append({"role": "assistant", "content": model_data.rstrip()})
    prompt = _build_prompt(system_message, past_messages, user_message)

    top_k = int(top_k or 0)
    generate_kwargs = dict(
        # A near-zero temperature is floored at 0.01 (the original code already
        # did this; sampling is always enabled below).
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=int(max_new_tokens),
        top_p=float(top_p),
        # The slider allows 0; treat that as "no top-k filtering" rather than
        # sending a non-positive value to the endpoint.
        top_k=top_k if top_k > 0 else None,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        truncate=4096,
        seed=seed,
        stop_sequences=["<|end|>"],
    )

    client = Client(
        model2endpoint[model_name],
        headers={"Authorization": f"Bearer {API_TOKEN}"},
        timeout=60,
    )
    stream = client.generate_stream(prompt, **generate_kwargs)

    output = ""
    chat = _history_to_chat(history)
    response_started = False
    for response in stream:
        if response.token.special:
            continue
        output += response.token.text
        if response_started:
            history[-1] = output
        else:
            # First visible token opens the assistant slot. A flag (not the stream
            # index) is used so a leading special token can't make a later token
            # overwrite the user's message.
            history.append(" " + output)
            response_started = True
        chat = _history_to_chat(history)
        yield chat, history, user_message, ""

    if not response_started:
        # Nothing visible was generated; still refresh the UI once.
        yield chat, history, user_message, ""

    if HF_TOKEN and do_save:
        try:
            now = datetime.datetime.now()
            current_time = now.strftime("%Y-%m-%d %H:%M:%S")
            print(f"[{current_time}] Pushing prompt and completion to the Hub")
            save_inputs_and_outputs(now, prompt, output, generate_kwargs, model_name)
        except Exception as e:
            # Saving is best effort and must never break the chat.
            print(e)
    return chat, history, user_message, ""
def clear_chat():
    """Reset the conversation: return a fresh empty chat list and a fresh empty history list."""
    empty_chat = list()
    empty_history = list()
    return empty_chat, empty_history
def delete_last_turn(chat, history):
    """Remove the most recent exchange from both the chatbot and the raw history.

    Both lists are mutated in place and returned, in Gradio output order.
    Nothing changes unless both are non-empty.

    ``history`` is truncated to exactly two entries per remaining chat pair.
    A trailing user message that never got a reply (e.g. after a failed
    generation) is therefore dropped too. The original always popped two
    entries, which removed that message plus the previous answer and left
    ``history`` out of sync with ``chat``.
    """
    if chat and history:
        chat.pop(-1)
        del history[2 * len(chat):]
    return chat, history
def process_example(args):
    """Run ``generate`` to completion and return its final ``[chat, history]``.

    The original passed ``args`` as a single argument to ``generate`` (which
    needs at least 11) and unpacked its 4-tuples into two names, so it could
    never run.

    Args:
        args: sequence of positional arguments for ``generate``
            (RETRY_FLAG, model_name, system_message, user_message, chatbot,
            history, temperature, top_k, top_p, max_new_tokens,
            repetition_penalty[, do_save]).
            NOTE(review): no caller is visible here; this calling convention is
            assumed. Confirm before wiring it to ``gr.Examples``.

    Returns:
        ``[chat, history]`` from the last update ``generate`` yielded, or
        ``[[], []]`` if it yielded nothing.
    """
    chat, history = [], []
    for chat, history, *_ in generate(*args):
        pass
    return [chat, history]
# Regenerate response
def retry_last_answer(
    selected_model,
    system_message,
    user_message,
    chat,
    history,
    temperature,
    top_k,
    top_p,
    max_new_tokens,
    repetition_penalty,
    do_save,
):
    """Regenerate the assistant's reply to the most recent user message.

    If the last user message already has a reply, that reply is removed from
    ``chat`` and ``history`` first. If ``history`` ends with an unanswered user
    message (odd length), that message is retried as-is.

    When there is nothing to regenerate, the current state is yielded once. The
    textbox keeps the user's draft, and no call is made. (Previously this path
    crashed with an unbound ``RETRY_FLAG``.)

    Yields:
        The same ``(chat, history, last_user_message, textbox)`` updates as ``generate``.
    """
    last_turn_answered = len(history) % 2 == 0
    if not history or (last_turn_answered and not chat):
        yield chat, history, user_message, user_message
        return
    if last_turn_answered:
        # Drop the previous answer from both the displayed chat and the history.
        chat.pop(-1)
        history.pop(-1)
    # history now ends with the user message to answer again.
    yield from generate(
        True,  # RETRY_FLAG: reuse history[-1] and draw a random seed
        selected_model,
        system_message,
        history[-1],
        chat,
        history,
        temperature,
        top_k,
        top_p,
        max_new_tokens,
        repetition_penalty,
        do_save,
    )
with gr.Blocks(analytics_enabled=False) as demo:
    # Consent toggle for pushing prompts and completions to the dialogues dataset.
    with gr.Row():
        do_save = gr.Checkbox(
            value=True,
            label="Store data",
            info="You agree to the storage of your prompt and generated text for research and development purposes:",
        )
    # Model picker; defaults to the second entry of model_names.
    with gr.Row():
        selected_model = gr.Radio(choices=model_names, value=model_names[1], label="Select a model")
    with gr.Accordion(label="System Prompt", open=False, elem_id="parameters-accordion"):
        system_message = gr.Textbox(
            elem_id="system-message",
            placeholder="Below is a conversation between a human user and a helpful AI coding assistant.",
            show_label=False,
        )
    with gr.Row():
        with gr.Box():
            output = gr.Markdown()
            chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")
    with gr.Row():
        with gr.Column(scale=3):
            user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input")
            with gr.Row():
                send_button = gr.Button("Send", elem_id="send-btn", visible=True)
                regenerate_button = gr.Button("Regenerate", elem_id="retry-btn", visible=True)
                delete_turn_button = gr.Button("Delete last turn", elem_id="delete-btn", visible=True)
                clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)
            # Sampling controls passed straight through to generate().
            with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.2,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                top_k = gr.Slider(
                    label="Top-k",
                    value=50,
                    minimum=0.0,
                    maximum=100,
                    step=1,
                    interactive=True,
                    info="Sample from a shortlist of top-k tokens",
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.95,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=512,
                    minimum=0,
                    maximum=32000,
                    step=4,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    value=1.2,
                    minimum=0.0,
                    maximum=10,
                    step=0.1,
                    interactive=True,
                    info="The parameter for repetition penalty. 1.0 means no penalty.",
                )
    # Per-session raw history: [user, assistant, user, assistant, ...].
    history = gr.State([])
    # Hidden flag; always False for fresh sends (retries pass True themselves).
    RETRY_FLAG = gr.Checkbox(value=False, visible=False)
    # To clear out "message" input textbox and use this to regenerate message
    last_user_message = gr.State("")

    # Inputs shared by fresh generations and regenerations, in generate()'s order
    # after RETRY_FLAG.
    conversation_inputs = [
        selected_model,
        system_message,
        user_message,
        chatbot,
        history,
        temperature,
        top_k,
        top_p,
        max_new_tokens,
        repetition_penalty,
        do_save,
    ]
    chat_outputs = [chatbot, history, last_user_message, user_message]

    # Pressing Enter in the textbox and clicking Send both start a fresh generation.
    for register_event in (user_message.submit, send_button.click):
        register_event(
            generate,
            inputs=[RETRY_FLAG, *conversation_inputs],
            outputs=list(chat_outputs),
        )
    regenerate_button.click(
        retry_last_answer,
        inputs=list(conversation_inputs),
        outputs=list(chat_outputs),
    )
    delete_turn_button.click(delete_last_turn, [chatbot, history], [chatbot, history])
    # Clearing the chat and switching models both reset the conversation.
    clear_chat_button.click(clear_chat, outputs=[chatbot, history])
    selected_model.change(clear_chat, outputs=[chatbot, history])
demo.queue(concurrency_count=16).launch(debug=True)