# Sarvam-M-Demo / app.py
# Last commit by KingNish: "Update app.py" (b2e8189, verified)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import spaces
import time
# Load the Sarvam-M checkpoint once at import time; every request reuses it.
model_name = "sarvamai/sarvam-m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Braille spinner frames; the chat handler cycles through these as the title
# of the reasoning bubble while the model is still thinking.
indicators = [
    f"Thinking {frame}"
    for frame in ("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏")
]
# Title prefixes this app gives reasoning bubbles: "Thinking..." / the
# "Thinking <spinner>" frames while streaming, and "Thought for N seconds"
# once reasoning ends. Used to recognise those bubbles in incoming history.
_THOUGHT_TITLE_PREFIXES = ("Thought", "Thinking")


def _is_thought_message(message):
    """Return True if *message* is an assistant reasoning bubble made by this app."""
    if message.get("role") != "assistant":
        return False
    metadata = message.get("metadata")
    if not isinstance(metadata, dict):
        return False
    # The title key may be absent or explicitly None; treat both as "no title".
    title = metadata.get("title") or ""
    return title.startswith(_THOUGHT_TITLE_PREFIXES)


def _strip_thoughts(chat_history):
    """Return *chat_history* without reasoning bubbles, for building the prompt."""
    return [message for message in chat_history if not _is_thought_message(message)]


@spaces.GPU(duration=120)
def generate_response(prompt, chat_history):
    """Stream a model reply to *prompt*, appending it to *chat_history*.

    This is a Gradio generator handler. Each yield is ``(chat_history, "")``:
    the updated message list for the Chatbot, and an empty string that clears
    the input Textbox.

    The generated text is split on the first ``</think>`` marker. Text before
    it goes into an assistant message whose ``metadata["title"]`` animates
    through ``indicators`` and finally reads "Thought for N seconds". Text
    after it goes into a separate, ordinary assistant message.

    Args:
        prompt: The user's new message text.
        chat_history: List of ``{"role", "content", ...}`` message dicts
            (Gradio "messages" format). It is mutated in place and yielded
            back on every update.
    """
    chat_history.append({"role": "user", "content": prompt})
    yield chat_history, ""  # show the user turn and clear the textbox right away

    # Earlier reasoning bubbles are UI-only; don't feed them back to the model.
    text = tokenizer.apply_chat_template(
        _strip_thoughts(chat_history), tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # generate() runs in a background thread and pushes decoded text chunks
    # into the streamer, which this generator drains below.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=model_inputs.input_ids,
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=model_inputs.get("attention_mask"),
        max_new_tokens=8192,
        streamer=streamer,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    reasoning_content = ""
    content = ""
    reasoning_done = False
    indicator_index = 0
    start_time = time.time()
    chat_history.append({"role": "assistant", "content": reasoning_content, "metadata": {"title": "Thinking..."}})

    for new_text in streamer:
        if reasoning_done:
            content += new_text
            chat_history[-1]["content"] = content
        elif "</think>" in new_text:
            # Only the first marker ends reasoning. Split the chunk so the
            # tag itself is dropped, any reasoning before it stays in the
            # thought bubble, and only the remainder starts the answer.
            before, _, after = new_text.partition("</think>")
            reasoning_content += before
            chat_history[-1]["content"] = reasoning_content
            thought_duration = time.time() - start_time
            chat_history[-1]["metadata"] = {"title": f"Thought for {thought_duration:.2f} seconds"}
            reasoning_done = True
            content = after
            chat_history.append({"role": "assistant", "content": content})
        else:
            # Still reasoning: advance the spinner and extend the thought.
            indicator_index = (indicator_index + 1) % len(indicators)
            chat_history[-1]["metadata"] = {"title": indicators[indicator_index]}
            reasoning_content += new_text
            chat_history[-1]["content"] = reasoning_content
        yield chat_history, ""

    thread.join()

    if not reasoning_done:
        # The stream ended without "</think>" (e.g. hit max_new_tokens).
        # Replace the frozen spinner frame with a final title.
        thought_duration = time.time() - start_time
        chat_history[-1]["metadata"] = {"title": f"Thought for {thought_duration:.2f} seconds"}
        yield chat_history, ""
# Gradio UI: a title, a messages-format chat log, and one input box.
with gr.Blocks() as demo:
    gr.Markdown("# Sarvam M Demo")
    chatbot = gr.Chatbot(type="messages", height=500)
    msg = gr.Textbox(label="Your Message")
    # On Enter, stream generate_response's (history, "") yields into the
    # chatbot and the textbox; the empty string clears the textbox.
    msg.submit(
        fn=generate_response,
        inputs=[msg, chatbot],
        outputs=[chatbot, msg],
    )
if __name__ == "__main__":
    # Start the app; mcp_server=True asks Gradio to also serve the app's
    # functions over MCP (see Gradio's launch() documentation).
    demo.launch(mcp_server=True)