import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Configuration ---
# Set this to a Hugging Face Hub model ID (e.g. "HuggingFaceH4/zephyr-7b-beta")
# or to a local directory containing the model files; "." loads from the
# current working directory.
HUGGINGFACE_MODEL_ID = "."

# Load weights in half precision to reduce memory use; on CPU-only machines,
# torch.float32 is the safer choice.
TORCH_DTYPE = torch.float16

# Generation parameters
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95
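# TEMPERATURE rescales the logits before sampling (lower values are more
# deterministic), TOP_K restricts sampling to the 50 most likely tokens, and
# TOP_P (nucleus sampling) keeps the smallest token set whose cumulative
# probability reaches 0.95.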

# Populated once by load_model_and_tokenizer() at startup.
tokenizer = None
model = None


def load_model_and_tokenizer():
    """
    Loads the language model and tokenizer from the Hugging Face Hub or a
    local path. Called once when the Gradio app starts up.
    """
    global tokenizer, model

    if tokenizer is not None and model is not None:
        print("Model and tokenizer already loaded.")
        return

    print(f"Loading tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID)
        # Many causal LM tokenizers ship without a pad token; reuse EOS so
        # padding during generation does not fail.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

        print(f"Loading model from: {HUGGINGFACE_MODEL_ID}...")
        model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto",  # place weights on available GPU(s), else CPU
        )
        model.eval()
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model or tokenizer: {e}")
        print("Ensure the model ID is correct and that you have an internet "
              "connection for the initial download, or that the local path is valid.")
        tokenizer = None
        model = None
        raise RuntimeError("Failed to load model. Check your model ID/path and internet connection.") from e
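
# Note (optional, assumes the bitsandbytes package is installed): on GPUs with
# limited memory, the model could instead be loaded quantized, e.g. by passing
# load_in_8bit=True to AutoModelForCausalLM.from_pretrained above.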


def generate_response(message: str, history: list) -> list:
    """
    Generates a text response from the loaded model based on the user input
    and the chat history (a list of {"role", "content"} dicts).
    """
    global tokenizer, model

    if tokenizer is None or model is None:
        load_model_and_tokenizer()

    if tokenizer is None or model is None:
        # Model failed to load; surface the error in the chat window.
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Error: Chatbot model not loaded. Please check logs."})
        return history

    # Record the new user turn, then build the prompt from the full history.
    history.append({"role": "user", "content": message})

    try:
        input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback for tokenizers without a chat template: plain-text turns.
        # Skip the last history entry (the message just appended above) so
        # the user turn is not formatted twice.
        input_text = ""
        for item in history[:-1]:
            if item["role"] == "user":
                input_text += f"User: {item['content']}\n"
            elif item["role"] == "assistant":
                input_text += f"Assistant: {item['content']}\n"
        input_text += f"User: {message}\nAssistant:"

    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,  # avoids pad/eos ambiguity warnings
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the prompt.
    generated_token_ids = output_ids[0][inputs.input_ids.shape[-1]:]
    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    history.append({"role": "assistant", "content": generated_text})
    return history
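
# A possible extension (not implemented here): stream tokens to the UI by
# running model.generate in a background thread with a
# transformers.TextIteratorStreamer and yielding partial histories, so the
# reply appears incrementally.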


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Local Chatbot Powered by Hugging Face Transformers
        Type your message below and chat with the model loaded locally on your machine!
        """
    )

    chatbot = gr.Chatbot(label="Conversation", type="messages")
    with gr.Row():
        text_input = gr.Textbox(
            label="Your message",
            placeholder="Type your message here...",
            scale=4,
        )
        submit_button = gr.Button("Send", scale=1)
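
    # Note: with type="messages", the chatbot state is a list of
    # {"role": ..., "content": ...} dicts, the same format that
    # generate_response consumes and returns.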

    # Both the Send button and pressing Enter in the textbox trigger generation.
    submit_button.click(
        fn=generate_response,
        inputs=[text_input, chatbot],
        outputs=[chatbot],
        queue=True,
    )
    text_input.submit(
        fn=generate_response,
        inputs=[text_input, chatbot],
        outputs=[chatbot],
        queue=True,
    )
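
    # Note: the textbox keeps its text after a message is sent; to clear it,
    # one could add text_input to the outputs above and return an empty
    # string alongside the updated history.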

    def clear_chat():
        # Reset the conversation and clear the input box.
        return [], ""

    clear_button = gr.Button("Clear Chat")
    clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input])


if __name__ == "__main__":
    # Load the model once at startup so the first request is not delayed.
    load_model_and_tokenizer()
    demo.queue().launch(server_name="0.0.0.0")
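
# Run this script directly with Python, then open http://localhost:7860
# (Gradio's default port) in a browser. server_name="0.0.0.0" makes the app
# reachable from other machines on the local network.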