import gradio as gr
import os
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from threading import Thread
import torch
import time

# Set environment variables
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Apollo system prompt
SYSTEM_PROMPT = "You are Apollo, a multilingual medical model. You communicate with people and assist them."
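# Note: SYSTEM_PROMPT is currently only referenced in the commented-out
# chat-template path further below; the active generation path uses a plain
# "User:/Assistant:" prompt without a system message.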
LICENSE = """
<div style="font-family: monospace; white-space: pre; margin-top: 20px; line-height: 1.2;">
@misc{wang2024apollo,
      title={Apollo: Lightweight Multilingual Medical LLMs towards Democratizing Medical AI to 6B People},
      author={Xidong Wang and Nuo Chen and Junyin Chen and Yan Hu and Yidong Wang and Xiangbo Wu and Anningzhe Gao and Xiang Wan and Haizhou Li and Benyou Wang},
      year={2024},
      eprint={2403.03640},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

@misc{zheng2024efficientlydemocratizingmedicalllms,
      title={Efficiently Democratizing Medical LLMs for 50 Languages via a Mixture of Language Family Experts},
      author={Guorui Zheng and Xidong Wang and Juhao Liang and Nuo Chen and Yuping Zheng and Benyou Wang},
      year={2024},
      eprint={2410.10626},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2410.10626},
}
</div>
"""
# Apollo model options
APOLLO_MODELS = {
    "Apollo": [
        "FreedomIntelligence/Apollo-7B",
        "FreedomIntelligence/Apollo-6B",
        "FreedomIntelligence/Apollo-2B",
        "FreedomIntelligence/Apollo-0.5B",
    ],
    "Apollo2": [
        "FreedomIntelligence/Apollo2-7B",
        "FreedomIntelligence/Apollo2-3.8B",
        "FreedomIntelligence/Apollo2-2B",
    ],
    "Apollo-MoE": [
        "FreedomIntelligence/Apollo-MoE-7B",
        "FreedomIntelligence/Apollo-MoE-1.5B",
        "FreedomIntelligence/Apollo-MoE-0.5B",
    ],
}
# CSS styles
css = """
h1 {
    text-align: center;
    display: block;
}
.gradio-container {
    max-width: 1200px;
    margin: auto;
}
"""
# Global variables to store the currently loaded model and tokenizer
current_model = None
current_tokenizer = None
current_model_path = None
def load_model(model_path, progress=gr.Progress()):
    """Load the selected model and tokenizer."""
    global current_model, current_tokenizer, current_model_path

    # If the same model is already loaded, don't reload it
    if current_model_path == model_path and current_model is not None:
        return "Model already loaded, no need to reload."

    # Clean up the previously loaded model (if any)
    if current_model is not None:
        del current_model
        del current_tokenizer
        torch.cuda.empty_cache()

    progress(0.1, desc=f"Starting to load model {model_path}...")
    try:
        progress(0.3, desc="Loading tokenizer...")
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
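        # The Apollo-MoE checkpoints reference custom modeling code
        # (UpcyclingQwen2Moe*). The override below points auto_map at those
        # local files so that AutoConfig/AutoModelForCausalLM resolve the
        # custom classes; this assumes the files ship with the MoE repos.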
        if 'MoE' in model_path:
            config_moe = config
            # PretrainedConfig objects are not subscriptable; update auto_map via the attribute
            config_moe.auto_map["AutoConfig"] = "./configuration_upcycling_qwen2_moe.UpcyclingQwen2MoeConfig"
            config_moe.auto_map["AutoModelForCausalLM"] = "./modeling_upcycling_qwen2_moe.UpcyclingQwen2MoeForCausalLM"

        current_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

        progress(0.5, desc="Loading model...")
        current_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=torch.float16,
            config=config_moe if 'MoE' in model_path else config,
            trust_remote_code=True
        )
        current_model_path = model_path

        progress(1.0, desc="Model loading complete!")
        return f"Model {model_path} successfully loaded."
    except Exception as e:
        progress(1.0, desc="Model loading failed!")
        return f"Model loading failed: {str(e)}"
def generate_response_non_streaming(instruction, model_name, temperature=0.7, max_tokens=1024):
    """Generate a response from the Apollo model (non-streaming)."""
    global current_model, current_tokenizer, current_model_path
    print("instruction:", instruction)

    # If the model is not yet loaded, load it first
    if current_model_path != model_name or current_model is None:
        load_message = load_model(model_name)
        if "failed" in load_message.lower():
            return load_message

    try:
        # Use a simple prompt format directly instead of the model's chat template
        prompt = f"User:{instruction}\nAssistant:"
        print("prompt:", prompt)
        chat_input = current_tokenizer.encode(prompt, return_tensors="pt").to(current_model.device)

        # Generate the response
        output = current_model.generate(
            input_ids=chat_input,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=(temperature > 0),
            eos_token_id=current_tokenizer.eos_token_id  # Use <|endoftext|> as the stop token
        )

        # Decode and return the generated text
        generated_text = current_tokenizer.decode(output[0][len(chat_input[0]):], skip_special_tokens=True)
        print("generated_text:", generated_text)
        return generated_text
    except Exception as e:
        return f"Error while generating response: {str(e)}"
    # try:
    #     # Check whether the model has a chat template
    #     if hasattr(current_tokenizer, 'chat_template') and current_tokenizer.chat_template:
    #         # Use the model's chat template
    #         messages = [
    #             {"role": "system", "content": SYSTEM_PROMPT},
    #             {"role": "user", "content": instruction}
    #         ]
    #         # Format the input with the model's chat template
    #         chat_input = current_tokenizer.apply_chat_template(
    #             messages,
    #             tokenize=True,
    #             return_tensors="pt"
    #         ).to(current_model.device)
    #     else:
    #         # Use the specified prompt format
    #         prompt = f"User:{instruction}\nAssistant:"
    #         chat_input = current_tokenizer.encode(prompt, return_tensors="pt").to(current_model.device)
    #
    #     # Get the token id of <|endoftext|>, used to stop generation
    #     eos_token_id = current_tokenizer.eos_token_id
    #
    #     # Generate the response
    #     output = current_model.generate(
    #         input_ids=chat_input,
    #         max_new_tokens=max_tokens,
    #         temperature=temperature,
    #         do_sample=(temperature > 0),
    #         eos_token_id=current_tokenizer.eos_token_id  # Use <|endoftext|> as the stop token
    #     )
    #
    #     # Decode and return the generated text
    #     generated_text = current_tokenizer.decode(output[0][len(chat_input[0]):], skip_special_tokens=True)
    #     return generated_text
    # except Exception as e:
    #     return f"Error while generating response: {str(e)}"
def update_chat_with_response(chatbot, instruction, model_name, temperature, max_tokens):
    """Update the chatbot with a non-streaming response."""
    global current_model, current_tokenizer, current_model_path

    # If the model is not yet loaded, load it first
    if current_model_path != model_name or current_model is None:
        load_result = load_model(model_name)
        if "failed" in load_result.lower():
            new_chat = list(chatbot)
            new_chat[-1] = (instruction, load_result)
            return new_chat

    # Generate a response using the non-streaming function
    response = generate_response_non_streaming(instruction, model_name, temperature, max_tokens)

    # Create a copy of the current chat history and add the response
    new_chat = list(chatbot)
    new_chat[-1] = (instruction, response)
    return new_chat

def on_model_series_change(model_series):
    """Update the available model list based on the selected model series."""
    if model_series in APOLLO_MODELS:
        return gr.update(choices=APOLLO_MODELS[model_series], value=APOLLO_MODELS[model_series][0])
    return gr.update(choices=[], value=None)
# Create the Gradio interface
with gr.Blocks(css=css) as demo:
    # Title and description
    favicon = "🩺"
    gr.Markdown(
        f"""# {favicon} Apollo Playground
This is a demo of the multilingual medical model series **[Apollo](https://github.com/FreedomIntelligence/Apollo)** made by **[FreedomIntelligence](https://huggingface.co/FreedomIntelligence)**.
[Apollo1](https://arxiv.org/abs/2403.03640) supports 6 languages. [Apollo2](https://arxiv.org/abs/2410.10626) and [Apollo-MoE](https://arxiv.org/abs/2410.10626) support 50 languages.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Model selection controls
            model_series = gr.Dropdown(
                choices=list(APOLLO_MODELS.keys()),
                value="Apollo",
                label="Select Model Series",
                info="First choose Apollo, Apollo2 or Apollo-MoE"
            )
            model_name = gr.Dropdown(
                choices=APOLLO_MODELS["Apollo"],
                value=APOLLO_MODELS["Apollo"][0],
                label="Select Model Size",
                info="Select the specific model size based on the chosen model series"
            )

            # Parameter settings
            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.05,
                    label="Temperature"
                )
                max_tokens = gr.Slider(
                    minimum=128,
                    maximum=2048,
                    value=1024,
                    step=32,
                    label="Maximum Tokens"
                )

            # Load model button
            load_button = gr.Button("Load Model")
            model_status = gr.Textbox(label="Model Status", value="No model loaded yet")

        with gr.Column(scale=2):
            # Chat interface
            chatbot = gr.Chatbot(label="Conversation", height=500, value=[])  # Initialize with an empty list
            user_input = gr.Textbox(
                label="Input Medical Question",
                placeholder="Example: What are the symptoms of hypertension? 高血压有哪些症状?",
                lines=3
            )
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear Chat")
    # Event handling
    # Update model selection when the model series changes
    model_series.change(
        fn=on_model_series_change,
        inputs=model_series,
        outputs=model_name
    )

    # Load model
    load_button.click(
        fn=load_model,
        inputs=model_name,
        outputs=model_status
    )

    # Bind message submission
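    # Gradio only passes the current UI state of a component into a callback
    # when that component is listed in `inputs`; reading `component.value`
    # inside the handler returns the initial default instead. The handler
    # below therefore takes the model choice and sliders as explicit arguments.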
    def process_message(message, chat_history, model_name_value, temperature_value, max_tokens_value):
        """Process the user message and generate a response."""
        if message.strip() == "":
            return "", chat_history

        # Add the user message to the chat history
        chat_history = list(chat_history)
        chat_history.append((message, None))

        # Generate the response
        response = generate_response_non_streaming(message, model_name_value, temperature_value, max_tokens_value)

        # Add the response to the chat history
        chat_history[-1] = (message, response)
        return "", chat_history

    submit_event = user_input.submit(
        fn=process_message,
        inputs=[user_input, chatbot, model_name, temperature, max_tokens],
        outputs=[user_input, chatbot]
    )
    submit_button.click(
        fn=process_message,
        inputs=[user_input, chatbot, model_name, temperature, max_tokens],
        outputs=[user_input, chatbot]
    )
    # Clear chat
    clear_button.click(
        fn=lambda: [],
        outputs=chatbot
    )
    # # Handle message submission
    # def user_message_submitted(message, chat_history):
    #     """Handle a user-submitted message."""
    #     # Ensure chat_history is a list
    #     if chat_history is None:
    #         chat_history = []
    #     if message.strip() == "":
    #         return "", chat_history
    #     # Add the user message to the chat history
    #     chat_history = list(chat_history)
    #     chat_history.append((message, None))
    #     return "", chat_history

    # # Bind message submission
    # submit_event = user_input.submit(
    #     fn=user_message_submitted,
    #     inputs=[user_input, chatbot],
    #     outputs=[user_input, chatbot]
    # ).then(
    #     fn=update_chat_with_response,
    #     inputs=[chatbot, user_input, model_name, temperature, max_tokens],
    #     outputs=chatbot
    # )

    # submit_button.click(
    #     fn=user_message_submitted,
    #     inputs=[user_input, chatbot],
    #     outputs=[user_input, chatbot]
    # ).then(
    #     fn=update_chat_with_response,
    #     inputs=[chatbot, user_input, model_name, temperature, max_tokens],
    #     outputs=chatbot
    # )

    # # Clear chat
    # clear_button.click(
    #     fn=lambda: [],
    #     outputs=chatbot
    # )
    examples = [
        ["Últimamente tengo la tensión un poco alta, ¿cómo debo adaptar mis hábitos?"],
        ["What are the common side effects of metformin?"],
        ["中医和西医在治疗高血压方面有什么不同的观点?"],
        ["मेरा सिर दर्द कर रहा है, मुझे क्या करना चाहिए?"],
        ["Comment savoir si je suis diabétique ?"],
        ["ما الدواء الذي يمكنني تناوله إذا لم أستطع النوم ليلاً؟"]
    ]
    gr.Examples(
        examples=examples,
        inputs=user_input
    )

    # gr.HTML(LICENSE)

if __name__ == "__main__":
    demo.launch()