Update main.py
main.py
CHANGED
@@ -1,8 +1,9 @@
 import os
-
+import ast  # Import ast for safer evaluation of literals
+from typing import List, Tuple, Any  # Import Any for better type hint after literal_eval
 from fastapi import FastAPI, Form, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
-
+# Removed BaseModel as we are using Form
 from text_generation import Client
 from deep_translator import GoogleTranslator
 
@@ -18,10 +19,13 @@ API_URL = "https://api-inference.huggingface.co/models/" + model_id
 client = Client(
     API_URL,
     headers={"Authorization": f"Bearer {HF_TOKEN}"},
+    timeout=120  # Add a timeout for the client
 )
 
-
-
+# Correct End Of Text token for Llama 3 / 3.1
+EOT_TOKEN = "<|eot_id|>"
+# Expected header before the assistant's response starts in the Llama 3 format
+ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n\n"
 
 app = FastAPI()
 
@@ -34,32 +38,64 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Pydantic model for request body
-class ChatRequest(BaseModel):
-    prompt: str
-    history: List[Tuple[str, str]]
-
 DEFAULT_SYSTEM_PROMPT = """\
-You are a helpful, respectful and honest assistant
+You are a helpful, respectful and honest assistant . Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
 """
 
+# Updated get_prompt function using Llama 3.1 instruction format
 def get_prompt(message: str, chat_history: List[Tuple[str, str]],
-               system_prompt: str) -> str:
-
-
+               system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
+    """
+    Formats the chat history and current message into the Llama 3.1 instruction format.
+    """
+    prompt_parts = []
+    prompt_parts.append("<|begin_of_text|>")
+
+    # Add system prompt if provided
+    if system_prompt:
+        prompt_parts.append(f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}{EOT_TOKEN}")
+
+    # Add previous chat turns
     for user_input, response in chat_history:
-
-
-
-
-
-
+        # Ensure inputs/responses are strings before including them
+        user_input_str = str(user_input).strip()
+        response_str = str(response).strip() if response is not None else ""  # Handle potential None in history
+
+        prompt_parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{user_input_str}{EOT_TOKEN}")
+        # Ensure response is not empty before adding assistant turn
+        if response_str:
+            prompt_parts.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n{response_str}{EOT_TOKEN}")
 
+    # Add current user message and prepare for assistant response
+    message_str = str(message).strip()
+    prompt_parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{message_str}{EOT_TOKEN}")
+    prompt_parts.append(ASSISTANT_HEADER)  # This is where the model starts generating
+
+    return "".join(prompt_parts)
+
+# Keep app.post with Form data parameters
 @app.post("/generate/")
 async def generate_response(prompt: str = Form(...), history: str = Form(...)):
     try:
-
-
+        # --- SAFELY Parse History ---
+        # Replace eval() with ast.literal_eval() for safety
+        # It can safely evaluate strings containing Python literals (like lists, tuples, strings, numbers, dicts, booleans, None)
+        try:
+            parsed_history: Any = ast.literal_eval(history)
+
+            # Basic validation to ensure it looks like the expected format
+            if not isinstance(parsed_history, list):
+                raise ValueError("History is not a list.")
+            # You could add more checks, e.g., if items are tuples of strings
+
+            chat_history: List[Tuple[str, str]] = [(str(u), str(a)) for u, a in parsed_history]  # Ensure elements are strings
+
+        except (ValueError, SyntaxError, TypeError) as e:
+            # Catch errors if the history string is not a valid literal or not the right structure
+            raise HTTPException(status_code=400, detail=f"Invalid history format: {e}")
+        # --- End Safely Parse History ---
+
+        system_prompt = DEFAULT_SYSTEM_PROMPT  # You could make this configurable too
         message = prompt
 
         prompt_text = get_prompt(message, chat_history, system_prompt)
@@ -69,23 +105,50 @@ async def generate_response(prompt: str = Form(...), history: str = Form(...)):
             do_sample=True,
             top_p=0.9,
             top_k=50,
-            temperature=0.1,
+            temperature=0.1,  # Keep temperature low for more predictable output
+            stop_sequences=[EOT_TOKEN],  # Explicitly tell the API to stop at EOT
+            # return_full_text=False  # Might need to experiment with this depending on API behavior
         )
-
-
-
-
-
-
-
-
-
-        #
-
-
+
+        # Using generate (non-streaming) is simpler for post-processing
+        # If you need streaming, the logic below needs to be adapted for the stream loop
+        # Let's use generate first as it makes cleaning easier
+        output_obj = client.generate(prompt_text, **generate_kwargs)
+        output = output_obj.generated_text  # Get the final generated text
+
+        # --- Post-processing the output ---
+        # The model *should* generate only the assistant response after the ASSISTANT_HEADER.
+        # However, sometimes leading whitespace or unexpected tokens can occur.
+        # Let's strip potential leading whitespace and the expected header if it's accidentally generated.
+        # Crucially, remove the ASSISTANT_HEADER that was part of the prompt structure
+        expected_start = ASSISTANT_HEADER.strip()
+        if output.startswith(expected_start):
+            output = output[len(expected_start):].strip()
+        else:
+            # Fallback: If it doesn't start with the expected header, just strip leading/trailing whitespace
+            output = output.strip()
+
+        # Remove the EOT token if it's still present at the end
+        if output.endswith(EOT_TOKEN):
+            output = output[:-len(EOT_TOKEN)].strip()
+        # --- End Post-processing ---
+
+        # Ensure the output is not empty before translating
+        if not output:
+            # If the model produced no output after cleaning, return a default or error
+            translated_output = "Error: Model did not produce a response or response was filtered."
+        else:
+            # Translate the cleaned response to Arabic
+            translator = GoogleTranslator(source='auto', target='ar')
+            translated_output = translator.translate(output)
 
         return {"response": translated_output}
 
+    except HTTPException as he:
+        # Re-raise HTTPExceptions (like the 400 for invalid history)
+        raise he
     except Exception as e:
-
-
+        # Log other errors for debugging purposes on the server
+        print(f"An error occurred during generation or translation: {e}")
+        raise HTTPException(status_code=500, detail=f"An internal error occurred: {e}")