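"""Gradio chat demo for inclusionAI/Ling-mini-2.0: loads the model with transformers and streams responses via TextIteratorStreamer."""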
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import gradio as gr
import json
import subprocess
import os
def install_vllm_from_patch():
    """Run install.sh to install vLLM from the bundled patch file."""
    script_path = "./install.sh"
    if not os.path.exists(script_path):
        print("Error: install.sh does not exist!")
        return False
    try:
        print("Running install.sh ...")
        result = subprocess.run(
            ["bash", script_path],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            timeout=300,
        )
        print(f"result: {result}")
        return True
    except Exception as e:
        print(f"Error: {str(e)}")
        return False

# Optionally install vLLM from the patch file (disabled; this demo runs on transformers directly)
# install_vllm_from_patch()
# Load model and tokenizer
model_name = "inclusionAI/Ling-mini-2.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
).eval()
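# Note: device_map="auto" above dispatches weights across the available devices and requires the `accelerate` package.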
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    # temperature,
    # top_p,
):
    """Stream a chat completion for `message` given the conversation `history`."""
    # For `huggingface_hub` Inference API support (used by the commented-out client below), see:
    # https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
    # client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
    if len(system_message) == 0:
        # Default system prompt (Chinese): "Who are you? I am Ling, an AI assistant developed by Ant Group."
        system_message = "## 你是谁\n\n我是百灵(Ling),一个由蚂蚁集团(Ant Group) 开发的AI智能助手"
    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})
    print(f"messages: {json.dumps(messages, ensure_ascii=False, indent=2)}")
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    model_inputs = tokenizer([text], return_tensors="pt", return_token_type_ids=False).to(model.device)
    print(f"max_new_tokens={max_tokens}")
    # Add generation arguments; sampling parameters are left commented out, so the model defaults apply
    model_inputs.update(
        dict(
            max_new_tokens=max_tokens,
            streamer=streamer,
            # temperature=0.7,
            # top_p=1,
            # presence_penalty=1.5,
        )
    )
    # Start a separate thread for model generation to allow streaming output
    thread = Thread(
        target=model.generate,
        kwargs=model_inputs,
    )
    thread.start()

    # Accumulate and yield text tokens as they are generated
    acc_text = ""
    for text_token in streamer:
        acc_text += text_token  # Append the generated token to the accumulated text
        yield acc_text  # Yield the accumulated text

    # Ensure the generation thread completes
    thread.join()
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
chatbot = gr.ChatInterface(
respond,
type="messages",
additional_inputs=[
gr.Textbox(value="", label="System message"),
gr.Slider(minimum=1, maximum=32000, value=512, step=1, label="Max new tokens"),
# gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
# gr.Slider(
# minimum=0.1,
# maximum=1.0,
# value=0.95,
# step=0.05,
# label="Top-p (nucleus sampling)",
# ),
],
)
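# With type="messages", Gradio passes `history` to `respond` as a list of {"role": ..., "content": ...}
# dicts, which is why it can be extended directly into the chat-template input above.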
with gr.Blocks() as demo:
    # with gr.Sidebar():
    #     gr.LoginButton()
    chatbot.render()

if __name__ == "__main__":
    demo.launch()
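# To run locally: `python app.py` (assuming this file is saved as app.py);
# Gradio serves the UI at http://localhost:7860 by default.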