import gradio as gr
import torch
from transformers import pipeline


def initialize_model():
    """Initialize the text-generation pipeline with device detection."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    try:
        generator = pipeline(
            "text-generation",
            model="akhaliq/gemma-3-270m-gradio-coder",
            device=device,
        )
        return generator
    except Exception as e:
        print(f"Error loading model: {e}")

        # If loading on the GPU failed (e.g. out of memory), retry once on
        # the CPU before giving up.
        if device == "cuda":
            print("Falling back to CPU...")
            generator = pipeline(
                "text-generation",
                model="akhaliq/gemma-3-270m-gradio-coder",
                device="cpu",
            )
            return generator
        raise
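

# Note (assumption about the installed transformers version): recent releases
# accept "cpu"/"cuda" strings for pipeline(device=...); older releases expect
# an integer device index instead (-1 for CPU, 0 for the first GPU).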


# Load the model once at import time so every request shares one pipeline.
print("Loading model...")
generator = initialize_model()
print("Model loaded successfully!")


def chat_response(message, history):
    """Generate a response for the chatbot."""
    try:
        # The pipeline accepts a chat-style message list and applies the
        # model's chat template automatically.
        input_message = [{"role": "user", "content": message}]

        output = generator(
            input_message,
            max_new_tokens=128,
            return_full_text=False,  # return only the newly generated text
            do_sample=True,
            temperature=0.7,
            pad_token_id=generator.tokenizer.eos_token_id,
        )[0]

        response = output["generated_text"]
        return response

    except Exception as e:
        return f"Sorry, I encountered an error: {e}"


def create_chatbot():
    """Create the Gradio chatbot interface."""
    css = """
    .gradio-container {
        max-width: 800px !important;
        margin: auto !important;
    }
    .chat-message {
        padding: 10px !important;
        margin: 5px !important;
        border-radius: 10px !important;
    }
    """

    with gr.Blocks(css=css, title="AI Chatbot") as demo:
        gr.Markdown("# 🤖 AI Chatbot")
        gr.Markdown("*Powered by the Gemma-3-270m model via Transformers*")

        chatbot = gr.Chatbot(
            height=500,
            bubble_full_width=False,
            show_label=False,
        )

        with gr.Row():
            msg = gr.Textbox(
                placeholder="Type your message here...",
                show_label=False,
                scale=4,
            )
            send_btn = gr.Button("Send", scale=1, variant="primary")
            clear_btn = gr.Button("Clear", scale=1)

        gr.Examples(
            examples=[
                "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?",
                "What's the most important lesson you've learned in life?",
                "How do you think AI will change the world in the next 10 years?",
                "What would you do if you had unlimited resources for one day?",
            ],
            inputs=msg,
        )

        def respond(message, chat_history):
            # Ignore empty or whitespace-only submissions.
            if not message.strip():
                return chat_history, ""

            bot_message = chat_response(message, chat_history)

            # History is stored as (user, bot) tuples, Gradio's classic
            # Chatbot message format.
            chat_history.append((message, bot_message))
            return chat_history, ""

        def clear_chat():
            return [], ""

        # Send on Enter or via the Send button; Clear resets both components.
        msg.submit(respond, [msg, chatbot], [chatbot, msg])
        send_btn.click(respond, [msg, chatbot], [chatbot, msg])
        clear_btn.click(clear_chat, None, [chatbot, msg])

    return demo


if __name__ == "__main__":
    print("Creating Gradio interface...")
    demo = create_chatbot()
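
    # Optional (sketch): if several users may chat at once, enabling the
    # request queue serializes calls into the single shared pipeline.
    # demo.queue()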

    print("Starting Gradio server...")
    demo.launch(
        share=False,
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        show_error=True,
    )