import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import spaces
import time

# Load the model and tokenizer
model_name = "sarvamai/sarvam-m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

# Spinner frames shown in the message title while the model is reasoning
indicators = ["Thinking ⠋", "Thinking ⠙", "Thinking ⠹", "Thinking ⠸", "Thinking ⠼",
              "Thinking ⠴", "Thinking ⠦", "Thinking ⠧", "Thinking ⠇", "Thinking ⠏"]


@spaces.GPU(duration=120)
def generate_response(prompt, chat_history):
    chat_history.append({"role": "user", "content": prompt})
    yield chat_history, ""
    print(chat_history)  # debug: log the running history

    # Rebuild the history for the model, dropping the assistant's thought
    # messages so only user turns and final answers are sent back
    processed_chat_history = []
    for message in chat_history:
        if message["role"] == "assistant":
            metadata = message.get("metadata", {})
            if isinstance(metadata, dict) and metadata.get("title", "").startswith(("Thinking", "Thought")):
                continue
        processed_chat_history.append(message)

    text = tokenizer.apply_chat_template(processed_chat_history, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Use TextIteratorStreamer for streaming
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation in a background thread so tokens can be streamed as they arrive
    generation_kwargs = dict(
        input_ids=model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=8192,
        streamer=streamer,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Track the reasoning (inside <think>...</think>) and the final answer separately
    reasoning_content = ""
    content = ""
    reasoning_done = False
    start_time = time.time()
    chat_history.append({"role": "assistant", "content": reasoning_content, "metadata": {"title": "Thinking..."}})
    indicator_index = 0
    for new_text in streamer:
        if "</think>" in new_text:
            # The closing tag marks the end of the reasoning phase
            reasoning_done = True
            thought_duration = time.time() - start_time
            chat_history[-1]["metadata"] = {"title": f"Thought for {thought_duration:.2f} seconds"}
            chat_history.append({"role": "assistant", "content": content})
        if not reasoning_done:
            # Advance the thinking spinner and accumulate reasoning text
            indicator_index = (indicator_index + 1) % len(indicators)
            chat_history[-1]["metadata"] = {"title": indicators[indicator_index]}
            reasoning_content += new_text
            chat_history[-1]["content"] = reasoning_content
        else:
            # Strip the closing tag so it never leaks into the visible answer
            content += new_text.replace("</think>", "")
            chat_history[-1]["content"] = content
        yield chat_history, ""


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Sarvam M Demo")
    chatbot = gr.Chatbot(height=500, type="messages")
    msg = gr.Textbox(label="Your Message")
    msg.submit(generate_response, [msg, chatbot], [chatbot, msg])

if __name__ == "__main__":
    demo.launch(mcp_server=True)