import os
import ast
from typing import List, Tuple, Any

from fastapi import FastAPI, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from text_generation import Client
from deep_translator import GoogleTranslator


HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Please set the HF_TOKEN environment variable.")


model_id = 'NousResearch/Hermes-3-Llama-3.1-8B'
API_URL = "https://api-inference.huggingface.co/models/" + model_id
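# The token is read from the environment; for a local run it can be exported before
# starting the app, e.g. `export HF_TOKEN=...` (the value is a personal Hugging Face
# access token; the exact token is not part of this file).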

client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    timeout=120,
)

# Special tokens from the Llama 3.1 chat template, used when building and cleaning up prompts.
EOT_TOKEN = "<|eot_id|>"
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n\n"

app = FastAPI()

# Wide-open CORS so any front end can call the API; tighten allow_origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something that is not correct. If you don't know the answer to a question, please don't share false information.\
"""


def get_prompt(message: str, chat_history: List[Tuple[str, str]],
               system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    """
    Formats the chat history and current message into the Llama 3.1 instruction format.
    """
    prompt_parts = []
    prompt_parts.append("<|begin_of_text|>")

    if system_prompt:
        prompt_parts.append(f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}{EOT_TOKEN}")

    for user_input, response in chat_history:
        user_input_str = str(user_input).strip()
        response_str = str(response).strip() if response is not None else ""

        prompt_parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{user_input_str}{EOT_TOKEN}")
        if response_str:
            prompt_parts.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n{response_str}{EOT_TOKEN}")

    message_str = str(message).strip()
    prompt_parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{message_str}{EOT_TOKEN}")
    prompt_parts.append(ASSISTANT_HEADER)

    return "".join(prompt_parts)
@app.post("/generate/") |
|
async def generate_response(prompt: str = Form(...), history: str = Form(...)): |
|
try: |
|
|
|
|
|
|
|
try: |
|
parsed_history: Any = ast.literal_eval(history) |
|
|
|
|
|
if not isinstance(parsed_history, list): |
|
raise ValueError("History is not a list.") |
|
|
|
|
|
chat_history: List[Tuple[str, str]] = [(str(u), str(a)) for u, a in parsed_history] |
|
|
|
except (ValueError, SyntaxError, TypeError) as e: |
|
|
|
raise HTTPException(status_code=400, detail=f"Invalid history format: {e}") |
|
|
|
|
|
system_prompt = DEFAULT_SYSTEM_PROMPT |
|
message = prompt |
|
|
|
prompt_text = get_prompt(message, chat_history, system_prompt) |
|
|
|
        # Sampling is enabled but with a low temperature, so output stays fairly deterministic.
        generate_kwargs = dict(
            max_new_tokens=1024,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            temperature=0.1,
            stop_sequences=[EOT_TOKEN],
        )

        output_obj = client.generate(prompt_text, **generate_kwargs)
        output = output_obj.generated_text

        # Strip the assistant header if the model echoes it back at the start of the output.
        expected_start = ASSISTANT_HEADER.strip()
        if output.startswith(expected_start):
            output = output[len(expected_start):].strip()
        else:
            output = output.strip()

        # Drop a trailing end-of-turn token if the stop sequence was included in the output.
        if output.endswith(EOT_TOKEN):
            output = output[:-len(EOT_TOKEN)].strip()

        if not output:
            translated_output = "Error: Model did not produce a response or response was filtered."
        else:
            # Translate the model output to Arabic before returning it.
            translator = GoogleTranslator(source='auto', target='ar')
            translated_output = translator.translate(output)

        return {"response": translated_output}

    except HTTPException as he:
        raise he
    except Exception as e:
        print(f"An error occurred during generation or translation: {e}")
        raise HTTPException(status_code=500, detail=f"An internal error occurred: {e}")
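
# A minimal sketch of how this service might be run and called, assuming the file is
# saved as app.py and that uvicorn, requests, and python-multipart are installed
# (FastAPI needs python-multipart to parse Form fields); the file name and port are
# assumptions, not part of the original code:
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/generate/",
#       data={"prompt": "What is FastAPI?", "history": "[('Hi', 'Hello!')]"},
#   )
#   print(resp.json()["response"])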