|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""访问文本模型的命令行界面""" |
|
|
|
import argparse |
|
import os |
|
from openai import OpenAI |
|
import gradio as gr |
|
import random |
|
# Seed the `random` module's global generator with a fixed value.
# NOTE(review): nothing in the visible code draws from `random` — confirm
# whether this seeding is still needed.
random.seed(42)

# Absolute path of the directory that contains this file.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

# System message placed at the start of every conversation sent to either
# the base model or the aligner model.
SYSTEM_PROMPT = "你是一个有帮助的AI助手,能够回答用户的问题并提供帮助。"
|
|
|
|
|
# API key handed to both clients.
# NOTE(review): presumably a placeholder accepted by the local servers —
# confirm, and avoid committing a real credential here.
openai_api_key = "jiayi"

# Ports of the two OpenAI-compatible endpoints (aligner model / base model).
aligner_port = 8013
base_port = 8011

# Connect through the loopback address. 0.0.0.0 is a wildcard *bind* address,
# not a destination: connecting to it only works on some platforms, whereas
# 127.0.0.1 reaches a locally bound server everywhere.
api_host = "127.0.0.1"
aligner_api_base = f"http://{api_host}:{aligner_port}/v1"
base_api_base = f"http://{api_host}:{base_port}/v1"

# Model names sent with each chat-completions request.
# NOTE(review): both are empty strings — presumably they must be set to the
# names the servers expose before requests will succeed; confirm.
aligner_model = ""
base_model = ""

# One client per endpoint; both share the same API key.
aligner_client = OpenAI(
    api_key=openai_api_key,
    base_url=aligner_api_base,
)

base_client = OpenAI(
    api_key=openai_api_key,
    base_url=base_api_base,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Example prompts passed as `examples=` to the Gradio chat interface.
TEXT_EXAMPLES = [
    "介绍一下北京大学的历史",
    "解释一下什么是深度学习",
    "写一首关于春天的诗",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def text_conversation(text: str, role: str = 'user'):
    """Wrap *text* in a one-element list holding a single chat message.

    Args:
        text: Message body, stored under the ``content`` key.
        role: Value stored under the ``role`` key; defaults to ``'user'``.

    Returns:
        A new list containing exactly one ``{'role': ..., 'content': ...}``
        dict, so callers can ``extend`` an existing conversation with it.
    """
    message = dict(role=role, content=text)
    return [message]
|
|
|
|
|
def _history_to_messages(history: list) -> list:
    """Convert Gradio chat history into a flat list of chat message dicts.

    Accepts both history layouts: ``(user, bot)`` pairs and role/content
    dicts. Empty or missing messages are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Dict layout: the entry already carries its own role.
            role, content = turn.get('role'), turn.get('content')
            if role in ('user', 'assistant') and content:
                messages.extend(text_conversation(content, role))
            continue
        # Pair layout: one user message and the bot reply that followed it.
        past_user_msg, past_bot_msg = turn
        if past_user_msg:
            messages.extend(text_conversation(past_user_msg, 'user'))
        if past_bot_msg:
            messages.extend(text_conversation(past_bot_msg, 'assistant'))
    return messages


def _iter_delta_text(stream):
    """Yield every text fragment carried by a streaming chat-completions response."""
    for chunk in stream:
        # A chunk may carry no choices at all (e.g. a usage-only chunk);
        # indexing choices[0] on it would raise IndexError.
        if not chunk.choices:
            continue
        piece = chunk.choices[0].delta.content
        if piece is not None:
            yield piece


def question_answering(message: str, history: list):
    """Answer *message* with the base model, then stream the aligner's correction.

    The base model receives the system prompt, the prior turns and the new
    question. Its streamed answer is shown inside a fenced block. The aligner
    model then receives the question together with the finished base answer
    and its streamed correction is appended below that block.

    Args:
        message: The user's current question.
        history: Prior turns as supplied by the chat UI, either ``(user, bot)``
            pairs or ``{'role': ..., 'content': ...}`` dicts.

    Yields:
        The full text to display so far; each yielded string supersedes the
        previous one.
    """
    conversation = text_conversation(SYSTEM_PROMPT, 'system')
    conversation.extend(_history_to_messages(history))
    conversation.extend(text_conversation(message))

    base_stream = base_client.chat.completions.create(
        model=base_model,
        stream=True,
        messages=conversation,
    )

    base_section = "🌟 **原始回答:**\n"
    # Show the header right away, before any token has arrived.
    yield base_section

    base_answer = ""
    for piece in _iter_delta_text(base_stream):
        base_answer += piece
        yield f"```bash\n{base_section}{base_answer}\n```"

    aligner_section = "\n**Aligner 修正中...**\n\n🌟 **修正后回答:**\n"
    # The base answer is final from here on; reuse its rendered block.
    base_block = f"```bash\n{base_section}{base_answer}\n```"
    yield f"{base_block}{aligner_section}"

    # The aligner sees only the current question and the base answer,
    # not the earlier turns.
    aligner_prompt = f'##Question: {message}\n##Answer: {base_answer}\n##Correction: '
    aligner_conversation = text_conversation(SYSTEM_PROMPT, 'system')
    aligner_conversation.extend(text_conversation(aligner_prompt))
    aligner_stream = aligner_client.chat.completions.create(
        model=aligner_model,
        stream=True,
        messages=aligner_conversation,
    )

    aligner_answer = ""
    for piece in _iter_delta_text(aligner_stream):
        # Strip the marker from the accumulated text rather than from the
        # fragment, so it is removed even when split across chunks.
        aligner_answer = (aligner_answer + piece).replace('##CORRECTION:', '')
        yield f"{base_block}{aligner_section}{aligner_answer}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_cli_args(argv=None):
    """Parse the launcher's command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``port`` (int), ``share`` (bool, defaults to True) and
        ``api_only`` (bool, defaults to False).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860, help="Gradio服务端口")
    # Boolean flags need real bool defaults: the strings 'True' and 'False'
    # are both truthy. Sharing stays on by default; --no-share turns it off.
    parser.add_argument("--share", dest="share", action="store_true", default=True,
                        help="是否创建公共链接")
    parser.add_argument("--no-share", dest="share", action="store_false",
                        help="不创建公共链接")
    # NOTE(review): api_only is parsed but not consulted by the visible code.
    parser.add_argument("--api-only", action="store_true", default=False,
                        help="只输出Python API调用示例")
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_cli_args()

    # Chat UI backed by the streaming question_answering generator.
    iface = gr.ChatInterface(
        fn=question_answering,
        title='Aligner',
        description='网络安全 Aligner',
        examples=TEXT_EXAMPLES,
        theme=gr.themes.Soft(
            text_size='lg',
            spacing_size='lg',
            radius_size='lg',
        ),
    )

    iface.launch(server_port=args.port, share=args.share)