import gradio as gr
from openai import OpenAI
import os

# Retrieve the access token from the environment variable
ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")

# Initialize the OpenAI client with the Hugging Face Inference API endpoint
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1/",
    api_key=ACCESS_TOKEN,
)
print("OpenAI client initialized.")

def respond(
    user_message,
    chat_history,
    system_msg,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    featured_model,
    custom_model
):
    """
    This function handles the chatbot response. It takes in:
    - user_message: the user's newly typed message
    - chat_history: the list of (user, assistant) message pairs
    - system_msg: the system instruction or system-level context
    - max_tokens: the maximum number of tokens to generate
    - temperature: sampling temperature
    - top_p: top-p (nucleus) sampling
    - frequency_penalty: penalize repeated tokens in the output
    - seed: a fixed seed for reproducibility; -1 means 'random'
    - featured_model: the chosen model name from 'Featured Models' radio
    - custom_model: the optional custom model that overrides the featured one if provided
    """
    print(f"Received user message: {user_message}")
    print(f"System message: {system_msg}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}, Freq-Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Featured model: {featured_model}")
    print(f"Custom model: {custom_model}")
    # Convert the seed to None if the user set it to -1 (meaning random)
    if seed == -1:
        seed = None

    # Decide which model to actually use:
    # if custom_model is non-empty, use that; otherwise use the chosen featured_model
    model_to_use = custom_model.strip() if custom_model.strip() != "" else featured_model
    # Provide a default fallback if for some reason both are empty
    if model_to_use.strip() == "":
        model_to_use = "meta-llama/Llama-3.3-70B-Instruct"
    print(f"Model selected for inference: {model_to_use}")

    # Construct the conversation history in the format required by HF's Inference API
    messages = []
    if system_msg.strip():
        messages.append({"role": "system", "content": system_msg.strip()})

    # Add the conversation history
    for user_text, assistant_text in chat_history:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})

    # Add the new user message to the conversation
    messages.append({"role": "user", "content": user_message})
    # We'll build the response token-by-token in a streaming loop
    response_so_far = ""
    print("Sending request to the Hugging Face Inference API...")

    # Make the streaming request to the HF Inference API
    try:
        for resp_chunk in client.chat.completions.create(
            model=model_to_use,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            seed=seed,
            messages=messages,
        ):
            # delta.content can be None (e.g. on the final chunk), so fall back to ""
            token_text = resp_chunk.choices[0].delta.content or ""
            response_so_far += token_text
            # Yield the updated message so the chatbot shows partial progress
            yield response_so_far
    except Exception as e:
        # If there's an error, at least show it in the chat
        error_text = f"[ERROR] {str(e)}"
        print(error_text)
        yield response_so_far + "\n\n" + error_text

    print("Completed response generation.")

#
# BUILDING THE GRADIO INTERFACE BELOW
#

# List of featured models; adjust or replace these placeholders with real text-generation models
models_list = [
    "meta-llama/Llama-3.3-70B-Instruct",
    "meta-llama/Llama-2-13B-chat-hf",
    "bigscience/bloom",
    "openlm-research/open_llama_7b",
    "facebook/opt-6.7b",
    "google/flan-t5-xxl",
]

def filter_models(search_term):
    """Filters the models_list by the given search_term and returns an update for the Radio component."""
    filtered = [m for m in models_list if search_term.lower() in m.lower()]
    return gr.update(choices=filtered)

with gr.Blocks(theme="Nymbo/Nymbo_Theme_5") as demo:
    gr.Markdown("# Serverless-TextGen-Hub (Enhanced)")
    gr.Markdown("**A comprehensive UI for text generation with a featured-models dropdown and a custom override**.")

    # We keep track of the conversation in a Gradio state variable (list of tuples)
    chat_history = gr.State([])

    # Tabs for organization
    with gr.Tab("Basic Settings"):
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                # System Message
                system_msg = gr.Textbox(
                    label="System message",
                    placeholder="Enter system-level instructions or context here.",
                    lines=2
                )
                # Accordion for featured models
                with gr.Accordion("Featured Models", open=True):
                    model_search = gr.Textbox(
                        label="Filter Models",
                        placeholder="Search for a featured model...",
                        lines=1
                    )
                    # The radio that lists our featured models
                    model_radio = gr.Radio(
                        label="Select a featured model below",
                        choices=models_list,
                        value=models_list[0],  # default
                        interactive=True
                    )
                    # Link the search box to update the model_radio choices
                    model_search.change(filter_models, inputs=model_search, outputs=model_radio)
                # Custom Model
                custom_model_box = gr.Textbox(
                    label="Custom Model (Optional)",
                    info="If provided, overrides the featured model above. e.g. 'meta-llama/Llama-3.3-70B-Instruct'",
                    placeholder="Your huggingface.co/username/model_name path"
                )
with gr.Tab("Advanced Settings"): | |
with gr.Row(): | |
max_tokens_slider = gr.Slider( | |
minimum=1, | |
maximum=4096, | |
value=512, | |
step=1, | |
label="Max new tokens" | |
) | |
temperature_slider = gr.Slider( | |
minimum=0.1, | |
maximum=4.0, | |
value=0.7, | |
step=0.1, | |
label="Temperature" | |
) | |
top_p_slider = gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.95, | |
step=0.05, | |
label="Top-P" | |
) | |
with gr.Row(): | |
freq_penalty_slider = gr.Slider( | |
minimum=-2.0, | |
maximum=2.0, | |
value=0.0, | |
step=0.1, | |
label="Frequency Penalty" | |
) | |
seed_slider = gr.Slider( | |
minimum=-1, | |
maximum=65535, | |
value=-1, | |
step=1, | |
label="Seed (-1 for random)" | |
) | |

    # Chat interface area: user input -> assistant output
    with gr.Row():
        chatbot = gr.Chatbot(
            label="TextGen Chat",
            height=500
        )

    # The user types a message here
    user_input = gr.Textbox(
        label="Your message",
        placeholder="Type your text prompt here..."
    )

    # "Send" button triggers our respond() function, updates the chatbot
    send_button = gr.Button("Send")

    # A Clear Chat button to reset the conversation
    clear_button = gr.Button("Clear Chat")

    # Define how the Send button updates the state and chatbot
    def user_submission(user_text, history):
        """
        Called first to add the user's message to the chat. Returns the updated
        chat_history with the user's message appended, plus an empty string to
        clear the user input box.
        """
        if user_text.strip() == "":
            return history, ""
        # Append user message to chat
        history = history + [(user_text, None)]
        return history, ""

    # Record the user's message first; the bot response is chained onto this event below
    send_event = send_button.click(
        fn=user_submission,
        inputs=[user_input, chat_history],
        outputs=[chat_history, user_input]
    )

    # Then we run the respond() generator (streaming) to produce the assistant message
    def bot_response(
        history,
        system_msg,
        max_tokens,
        temperature,
        top_p,
        freq_penalty,
        seed,
        featured_model,
        custom_model
    ):
        """
        Generate the assistant's response based on the conversation so far,
        the system message, and the sampling settings. Streams partial results,
        yielding the updated history for both the chatbot display and the state.
        """
        if not history:
            yield history, history
            return

        # The last user message is in history[-1][0]
        user_message = history[-1][0]

        # Pass everything to the respond() generator
        bot_stream = respond(
            user_message=user_message,
            chat_history=history[:-1],  # all pairs except the newly appended user message
            system_msg=system_msg,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=freq_penalty,
            seed=seed,
            featured_model=featured_model,
            custom_model=custom_model
        )
        for partial_text in bot_stream:
            # Keep updating the last message in the conversation with the partial text
            updated_history = history[:-1] + [(history[-1][0], partial_text)]
            yield updated_history, updated_history

    # Chain the streaming bot response after the user's message has been recorded,
    # so the conversation state always includes the new user turn before generation
    send_event.then(
        fn=bot_response,
        inputs=[
            chat_history,
            system_msg,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
            freq_penalty_slider,
            seed_slider,
            model_radio,
            custom_model_box
        ],
        outputs=[chatbot, chat_history]
    )

    # Clear Chat resets the stored history, the input box, and the chat display
    def clear_chat():
        return [], "", []

    clear_button.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chat_history, user_input, chatbot]
    )

# Launch the application
if __name__ == "__main__":
    print("Launching the Serverless-TextGen-Hub with Featured Models & Custom Model override.")
    demo.launch()