"""FastAPI service: chat with a hosted Llama-3.1-style model, reply in Arabic.

The single endpoint ``POST /generate/`` takes the current user message and the
previous conversation as form fields, builds a Llama 3.1 instruction prompt,
calls the Hugging Face inference API, cleans the generated text and returns
its Arabic translation.
"""
import ast
import logging
import os
from typing import Any, List, Tuple

from deep_translator import GoogleTranslator
from fastapi import FastAPI, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from text_generation import Client

logger = logging.getLogger(__name__)

# The service cannot authenticate against the inference API without a token,
# so fail fast at import time instead of on the first request.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Please set the HF_TOKEN environment variable.")

# Model and API setup.
model_id = 'NousResearch/Hermes-3-Llama-3.1-8B'
API_URL = "https://api-inference.huggingface.co/models/" + model_id

client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    timeout=120,  # seconds; passed to the client so a stalled call cannot hang forever
)

# End-of-turn token used by the Llama 3 / 3.1 prompt format.
EOT_TOKEN = "<|eot_id|>"
# Header that precedes the assistant's reply in the Llama 3 prompt format.
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n\n"

app = FastAPI()

# Allow CORS for the frontend application.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; replace "*" with the frontend's URL in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Written as adjacent string literals so the embedded newlines and the trailing
# space before the last newline are explicit rather than hidden in the layout.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, respectful and honest assistant . "
    "Always answer as helpfully as possible, while being safe. "
    "Your answers should not include any harmful, unethical, racist, sexist, "
    "toxic, dangerous, or illegal content. "
    "Please ensure that your responses are socially unbiased and positive in nature."
    "\n\n"
    "If a question does not make any sense, or is not factually coherent, "
    "explain why instead of answering something not correct. \n"
    "If you don't know the answer to a question, please don't share false information."
)


def get_prompt(message: str, chat_history: List[Tuple[str, str]], system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    """Format the chat history and current message as a Llama 3.1 prompt.

    Args:
        message: The current user message.
        chat_history: Previous ``(user_input, response)`` turns, oldest first.
            A turn whose response is ``None`` or blank contributes only its
            user part.
        system_prompt: System instructions; the system section is omitted
            entirely when this is empty.

    Returns:
        The full prompt, ending with the assistant header so that generation
        starts at the assistant's reply.

    NOTE(review): user-supplied text is inserted verbatim, so special tokens
    typed by a user (e.g. ``<|eot_id|>``) end up in the prompt unchanged.
    """
    prompt_parts = ["<|begin_of_text|>"]

    if system_prompt:
        prompt_parts.append(
            f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}{EOT_TOKEN}"
        )

    for user_input, response in chat_history:
        user_input_str = str(user_input).strip()
        response_str = str(response).strip() if response is not None else ""
        prompt_parts.append(
            f"<|start_header_id|>user<|end_header_id|>\n\n{user_input_str}{EOT_TOKEN}"
        )
        # Skip empty assistant turns instead of emitting an empty message.
        if response_str:
            prompt_parts.append(
                f"<|start_header_id|>assistant<|end_header_id|>\n\n{response_str}{EOT_TOKEN}"
            )

    message_str = str(message).strip()
    prompt_parts.append(
        f"<|start_header_id|>user<|end_header_id|>\n\n{message_str}{EOT_TOKEN}"
    )
    prompt_parts.append(ASSISTANT_HEADER)  # the model continues from here

    return "".join(prompt_parts)


def _parse_history(history: str) -> List[Tuple[str, str]]:
    """Parse the ``history`` form field into ``(user, assistant)`` string pairs.

    The field is a Python-literal list of 2-item lists/tuples and is evaluated
    with ``ast.literal_eval`` (never ``eval``). A blank field means "no
    previous turns". A ``None`` assistant entry becomes ``""`` so that it is
    skipped by ``get_prompt`` rather than rendered as the text ``"None"``.

    Raises:
        ValueError, SyntaxError, TypeError, MemoryError, RecursionError:
            If the field is not a valid literal or has the wrong structure.
    """
    if not history.strip():
        return []

    parsed: Any = ast.literal_eval(history)
    if not isinstance(parsed, list):
        raise ValueError("History is not a list.")

    turns: List[Tuple[str, str]] = []
    for index, turn in enumerate(parsed):
        # Require real pairs: plain unpacking would also accept e.g. the
        # two-character string "ab" as the pair ("a", "b").
        if not isinstance(turn, (list, tuple)) or len(turn) != 2:
            raise ValueError(f"History item {index} is not a (user, assistant) pair.")
        user_text, assistant_text = turn
        turns.append(
            (str(user_text), "" if assistant_text is None else str(assistant_text))
        )
    return turns


def _clean_output(output: str) -> str:
    """Strip whitespace, an echoed assistant header and a trailing EOT token."""
    cleaned = output.strip()
    # Strip first, then test for the header, so a header preceded by
    # whitespace is still recognised.
    expected_start = ASSISTANT_HEADER.strip()
    if cleaned.startswith(expected_start):
        cleaned = cleaned[len(expected_start):].strip()
    if cleaned.endswith(EOT_TOKEN):
        cleaned = cleaned[:-len(EOT_TOKEN)].strip()
    return cleaned


@app.post("/generate/")
async def generate_response(prompt: str = Form(...), history: str = Form(...)):
    """Generate a reply to ``prompt`` and return it translated to Arabic.

    Args:
        prompt: The current user message (form field).
        history: Previous turns as a Python-literal list of
            ``(user, assistant)`` pairs (form field); may be blank.

    Returns:
        ``{"response": <text>}`` where the text is the Arabic translation of
        the model's reply, or a fixed English error message when the model
        produced nothing after cleaning.

    Raises:
        HTTPException: 400 when ``history`` cannot be parsed, 500 when prompt
            building, generation or translation fails.
    """
    try:
        chat_history = _parse_history(history)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as exc:
        raise HTTPException(status_code=400, detail=f"Invalid history format: {exc}") from exc

    try:
        prompt_text = get_prompt(prompt, chat_history, DEFAULT_SYSTEM_PROMPT)

        generate_kwargs = dict(
            max_new_tokens=1024,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            temperature=0.1,  # kept low for more predictable output
            stop_sequences=[EOT_TOKEN],  # ask the API to stop at end of turn
        )

        # client.generate and translator.translate are blocking calls; run
        # them in the threadpool so they do not stall the event loop for
        # every other request.
        output_obj = await run_in_threadpool(client.generate, prompt_text, **generate_kwargs)
        output = _clean_output(output_obj.generated_text)

        if not output:
            translated_output = "Error: Model did not produce a response or response was filtered."
        else:
            translator = GoogleTranslator(source='auto', target='ar')
            translated_output = await run_in_threadpool(translator.translate, output)
    except Exception as exc:
        # Boundary handler: log with traceback, then surface a 500 to the client.
        logger.exception("An error occurred during generation or translation")
        raise HTTPException(status_code=500, detail=f"An internal error occurred: {exc}") from exc

    return {"response": translated_output}