# La / main.py
# Author: Luisgust
# Commit: 2c53650 "Update main.py" (verified)
import ast  # Import ast for safer evaluation of literals
import os
from typing import Any, List, Tuple  # Import Any for better type hint after literal_eval

from deep_translator import GoogleTranslator
from fastapi import FastAPI, Form, HTTPException  # Removed BaseModel as we are using Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from text_generation import Client
# Hugging Face API token, read from the environment at import time.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Reject both a missing and an empty value: an empty token would still be
# formatted into an empty "Bearer " Authorization header and presumably only
# fail later, on the first API call, instead of at startup.
if not HF_TOKEN:
    raise ValueError("Please set the HF_TOKEN environment variable.")
# Chat model served through the Hugging Face Inference API.
model_id = 'NousResearch/Hermes-3-Llama-3.1-8B'
API_URL = f"https://api-inference.huggingface.co/models/{model_id}"

# One shared text-generation client for all requests. The token is sent as a
# bearer credential; timeout=120 is passed through to the client (seconds).
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    timeout=120,
)
# Llama 3 / 3.1 chat-template markers.
#
# EOT_TOKEN closes every turn in the prompt. It is also the stop sequence
# requested from the server.
EOT_TOKEN = "<|eot_id|>"
# ASSISTANT_HEADER opens the assistant turn. The prompt ends with it, so the
# model's completion is the assistant's reply.
ASSISTANT_HEADER = (
    "<|start_header_id|>assistant<|end_header_id|>"
    "\n\n"
)
app = FastAPI()

# Browser origins allowed by CORS, given as a comma-separated list in the
# CORS_ALLOW_ORIGINS environment variable, e.g.
#   CORS_ALLOW_ORIGINS="https://app.example.com,https://staging.example.com"
# If the variable is unset or blank, it falls back to ["*"] (any origin), which
# is the previous hard-coded value. Set it to your frontend's URL(s) in
# production instead of editing code.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    # NOTE(review): allowing credentials together with a wildcard origin lets
    # any site make credentialed requests. The /generate/ endpoint does not
    # appear to rely on cookies or browser auth, so allow_credentials=False
    # may be sufficient; confirm with the frontend before changing.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# System prompt used for every conversation (passed to get_prompt). It is
# server-controlled text, so it is inserted into the template unescaped.
# The two paragraphs are separated by a blank line, and there is no trailing
# newline.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, respectful and honest assistant. "
    "Always answer as helpfully as possible, while being safe. "
    "Your answers should not include any harmful, unethical, racist, sexist, "
    "toxic, dangerous, or illegal content. "
    "Please ensure that your responses are socially unbiased and positive in nature."
    "\n\n"
    "If a question does not make any sense, or is not factually coherent, "
    "explain why instead of answering something not correct. "
    "If you don't know the answer to a question, please don't share false information."
)
def _escape_special_tokens(text: str) -> str:
    """Break up ``<|`` sequences in client-supplied text.

    Every control marker in this prompt template (``<|begin_of_text|>``,
    ``<|start_header_id|>``, ``<|end_header_id|>``, ``EOT_TOKEN``) begins
    with ``<|``. Putting a space between the two characters means user text
    can no longer contain those exact marker strings. A caller therefore
    cannot end their own turn early and forge a ``system`` or ``assistant``
    turn.

    A single ``str.replace`` pass is enough, because the replacement
    ``"< |"`` cannot itself form a new ``<|``.
    """
    return text.replace("<|", "< |")


# Updated get_prompt function using Llama 3.1 instruction format
def get_prompt(message: str, chat_history: List[Tuple[str, str]],
               system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    """Format the chat history and new message as a Llama 3.1 instruction prompt.

    Args:
        message: The new user message. It is converted with ``str()`` and
            stripped.
        chat_history: Earlier ``(user, assistant)`` turns, oldest first. An
            assistant reply that is ``None`` or blank after stripping is
            skipped, so only that user turn is emitted.
        system_prompt: Text of the system turn. The turn is omitted when this
            is falsy.

    Returns:
        The prompt string. It starts with ``<|begin_of_text|>`` and ends with
        ``ASSISTANT_HEADER``, so the model's completion is the assistant reply.

    User and assistant texts come from the HTTP client, so any ``<|``
    sequence in them is broken up by ``_escape_special_tokens`` to stop turn
    injection. The server-controlled ``system_prompt`` is used verbatim.

    NOTE(review): the Hermes-3 model card documents a ChatML
    (``<|im_start|>`` / ``<|im_end|>``) prompt format. Confirm that this
    Llama 3.1 template is what the deployed model expects.
    """
    def turn(role: str, content: str) -> str:
        # One complete turn: header, blank line, content, end-of-turn marker.
        return f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}{EOT_TOKEN}"

    prompt_parts = ["<|begin_of_text|>"]
    if system_prompt:
        prompt_parts.append(turn("system", system_prompt))

    for user_input, response in chat_history:
        prompt_parts.append(turn("user", _escape_special_tokens(str(user_input).strip())))
        # A None reply (e.g. a pending turn) must not become the text "None".
        response_str = "" if response is None else str(response).strip()
        if response_str:
            prompt_parts.append(turn("assistant", _escape_special_tokens(response_str)))

    prompt_parts.append(turn("user", _escape_special_tokens(str(message).strip())))
    # Leave the assistant turn open; generation continues from here.
    prompt_parts.append(ASSISTANT_HEADER)
    return "".join(prompt_parts)
def _parse_history(history: str) -> List[Tuple[str, str]]:
    """Parse the ``history`` form field into ``(user, assistant)`` string pairs.

    The field must be a Python-literal list of 2-item sequences, e.g.
    ``"[('hi', 'hello'), ('next', None)]"``. ``ast.literal_eval`` only
    evaluates literals, so no code from the request is executed.

    A ``None`` assistant reply becomes ``""``, so ``get_prompt`` skips that
    assistant turn instead of sending the text ``"None"``.

    Raises:
        HTTPException: 400 if the field is not a valid literal list of pairs.
    """
    try:
        parsed_history: Any = ast.literal_eval(history)
        if not isinstance(parsed_history, list):
            raise ValueError("History is not a list.")
        # Unpacking raises ValueError/TypeError for items that are not pairs.
        return [
            (str(user_turn), "" if reply is None else str(reply))
            for user_turn, reply in parsed_history
        ]
    # The ast.literal_eval docs note that malformed or deeply nested input can
    # also raise MemoryError or RecursionError. Those are client errors too,
    # not 500s.
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid history format: {e}") from e


def _clean_output(output: str) -> str:
    """Remove template residue from the raw completion text.

    Strips surrounding whitespace, a leading assistant header (in case the
    model repeats it, even after leading whitespace) and a trailing EOT_TOKEN.
    """
    output = output.strip()
    expected_start = ASSISTANT_HEADER.strip()
    if output.startswith(expected_start):
        output = output[len(expected_start):].strip()
    if output.endswith(EOT_TOKEN):
        output = output[:-len(EOT_TOKEN)].strip()
    return output


# Keep app.post with Form data parameters
@app.post("/generate/")
async def generate_response(prompt: str = Form(...), history: str = Form(...)):
    """Generate an assistant reply to ``prompt`` and return it translated to Arabic.

    Form fields:
        prompt: The new user message.
        history: Python-literal list of ``(user, assistant)`` pairs (see
            ``_parse_history``).

    Returns:
        ``{"response": <Arabic text>}``. If the cleaned model output is empty,
        ``"response"`` holds a fixed error string instead.

    Raises:
        HTTPException: 400 for a malformed ``history``; 500 for any other
            failure during generation or translation.
    """
    try:
        chat_history = _parse_history(history)
        prompt_text = get_prompt(prompt, chat_history, DEFAULT_SYSTEM_PROMPT)

        generate_kwargs = dict(
            max_new_tokens=1024,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            temperature=0.1,  # Keep temperature low for more predictable output
            stop_sequences=[EOT_TOKEN],  # Explicitly tell the API to stop at EOT
        )
        # client.generate and translator.translate are synchronous calls to
        # remote services. Calling them directly in this async handler would
        # block the event loop, and with it every other request, for the
        # whole round-trip. Run them in the threadpool instead.
        output_obj = await run_in_threadpool(client.generate, prompt_text, **generate_kwargs)
        output = _clean_output(output_obj.generated_text)

        if not output:
            # The model produced nothing usable after cleaning.
            translated_output = "Error: Model did not produce a response or response was filtered."
        else:
            translator = GoogleTranslator(source='auto', target='ar')
            translated_output = await run_in_threadpool(translator.translate, output)

        return {"response": translated_output}
    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 400 for bad history) unchanged.
        raise
    except Exception as e:
        # Log other errors for debugging purposes on the server
        print(f"An error occurred during generation or translation: {e}")
        raise HTTPException(status_code=500, detail=f"An internal error occurred: {e}") from e