# Hugging Face Space: Sarvam M demo (running on ZeroGPU)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import spaces
import time
# Load the model and tokenizer
model_name = "sarvamai/sarvam-m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
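# Note: bfloat16 with device_map="auto" assumes a CUDA-capable runtime; on a
# ZeroGPU Space the GPU is only attached while a @spaces.GPU-decorated function
# (generate_response below) is running.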
indicators = ["Thinking ⠋", "Thinking ⠙", "Thinking ⠹", "Thinking ⠸", "Thinking ⠼", "Thinking ⠴", "Thinking ⠦", "Thinking ⠧", "Thinking ⠇", "Thinking ⠏"]
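# (The frames above are braille spinner glyphs, cycled once per streamed chunk
# to animate the "Thinking" indicator; purely cosmetic.)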
@spaces.GPU  # request a GPU for the duration of each call (ZeroGPU)
def generate_response(prompt, chat_history):
    chat_history.append({"role": "user", "content": prompt})
    yield chat_history, ""
    print(chat_history)
    # Drop earlier thought-process messages so only user turns and final
    # answers are sent back to the model
    processed_chat_history = []
    for message in chat_history:
        metadata = message.get("metadata") or {}
        is_thought = (
            message["role"] == "assistant"
            and isinstance(metadata, dict)
            and metadata.get("title", "").startswith(("Thought", "Thinking"))
        )
        if not is_thought:
            processed_chat_history.append(message)
    text = tokenizer.apply_chat_template(processed_chat_history, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # Use TextIteratorStreamer for streaming
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
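    # skip_prompt drops the echoed prompt tokens and skip_special_tokens cleans
    # the decoded text; the <think>/</think> markers are assumed to decode as
    # ordinary text for this model, which the loop below relies on.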
    # Run generation on a background thread, streaming tokens into `streamer`
    generation_kwargs = dict(
        input_ids=model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=8192,
        streamer=streamer,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Initialize variables to capture reasoning content and main content
    reasoning_content = ""
    content = ""
    reasoning_done = False
    start_time = time.time()
    chat_history.append({"role": "assistant", "content": reasoning_content, "metadata": {"title": "Thinking..."}})
    indicator_index = 0
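    # In gr.Chatbot(type="messages"), an assistant message whose metadata has a
    # "title" renders as a collapsible thought bubble, keeping the reasoning
    # visually separate from the final answer.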
    for new_text in streamer:
        if not reasoning_done and "</think>" in new_text:
            # The marker ends the reasoning phase; split the chunk around it
            # so the tag itself never leaks into the visible answer
            before, _, after = new_text.partition("</think>")
            reasoning_content += before.replace("<think>", "")
            chat_history[-1]["content"] = reasoning_content
            thought_duration = time.time() - start_time
            chat_history[-1]["metadata"] = {"title": f"Thought for {thought_duration:.2f} seconds"}
            chat_history.append({"role": "assistant", "content": content})
            reasoning_done = True
            new_text = after
        if not reasoning_done:
            # Update the thinking indicator
            indicator_index = (indicator_index + 1) % len(indicators)
            chat_history[-1]["metadata"] = {"title": indicators[indicator_index]}
            reasoning_content += new_text.replace("<think>", "")
            chat_history[-1]["content"] = reasoning_content
        else:
            content += new_text
            chat_history[-1]["content"] = content
        yield chat_history, ""
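    # Ensure the background generation thread has fully finished
    thread.join()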
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Sarvam M Demo")
    chatbot = gr.Chatbot(height=500, type="messages")
    msg = gr.Textbox(label="Your Message")
    msg.submit(generate_response, [msg, chatbot], [chatbot, msg])
if __name__ == "__main__":
    demo.launch(mcp_server=True)
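# Note: mcp_server=True assumes Gradio 5.x with the MCP extra installed
# (pip install "gradio[mcp]"); drop the flag for a plain web demo.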