Luisgust committed · verified
Commit 2c53650 · 1 Parent(s): fa44779

Update main.py

Files changed (1):
  1. main.py  +99 -36
main.py CHANGED
@@ -1,8 +1,9 @@
  import os
- from typing import List, Tuple
+ import ast  # ast.literal_eval for safely parsing the history string
+ from typing import List, Tuple, Any  # Any for the value returned by literal_eval
  from fastapi import FastAPI, Form, HTTPException
  from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
+ # BaseModel is no longer needed: the endpoint takes Form data instead of a JSON body
  from text_generation import Client
  from deep_translator import GoogleTranslator
 
@@ -18,10 +19,13 @@ API_URL = "https://api-inference.huggingface.co/models/" + model_id
  client = Client(
      API_URL,
      headers={"Authorization": f"Bearer {HF_TOKEN}"},
+     timeout=120,  # add a timeout for the inference client
  )
 
- EOS_STRING = "</s>"
- EOT_STRING = "<EOT>"
+ # End-of-turn token used by the Llama 3 / 3.1 prompt format
+ EOT_TOKEN = "<|eot_id|>"
+ # Header that precedes the assistant's response in the Llama 3 prompt format
+ ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n\n"
 
  app = FastAPI()
 
@@ -34,32 +38,64 @@ app.add_middleware(
      allow_headers=["*"],
  )
 
- # Pydantic model for request body
- class ChatRequest(BaseModel):
-     prompt: str
-     history: List[Tuple[str, str]]
-
  DEFAULT_SYSTEM_PROMPT = """\
- You are a helpful, respectful and honest assistant with a deep knowledge of code and software design. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
  """
 
+ # get_prompt now builds the Llama 3.1 instruction format
  def get_prompt(message: str, chat_history: List[Tuple[str, str]],
-                system_prompt: str) -> str:
-     texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
-     do_strip = False
+                system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
+     """
+     Formats the chat history and the current message into the Llama 3.1 instruction format.
+     """
+     prompt_parts = []
+     prompt_parts.append("<|begin_of_text|>")
+
+     # Add the system prompt if provided
+     if system_prompt:
+         prompt_parts.append(f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}{EOT_TOKEN}")
+
+     # Add previous chat turns
      for user_input, response in chat_history:
-         user_input = user_input.strip() if do_strip else user_input
-         do_strip = True
-         texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
-     message = message.strip() if do_strip else message
-     texts.append(f'{message} [/INST]')
-     return ''.join(texts)
+         # Make sure both sides of the turn are strings before formatting
+         user_input_str = str(user_input).strip()
+         response_str = str(response).strip() if response is not None else ""  # history may contain None
+
+         prompt_parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{user_input_str}{EOT_TOKEN}")
+         # Only add an assistant turn when there is a non-empty response
+         if response_str:
+             prompt_parts.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n{response_str}{EOT_TOKEN}")
 
+     # Add the current user message and open the assistant turn
+     message_str = str(message).strip()
+     prompt_parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{message_str}{EOT_TOKEN}")
+     prompt_parts.append(ASSISTANT_HEADER)  # the model starts generating from here
+
+     return "".join(prompt_parts)
+
+ # The endpoint keeps its Form-data parameters
  @app.post("/generate/")
  async def generate_response(prompt: str = Form(...), history: str = Form(...)):
      try:
-         chat_history = eval(history)  # Convert history string back to list
-         system_prompt = DEFAULT_SYSTEM_PROMPT
+         # --- Safely parse the history ---
+         # eval() is replaced with ast.literal_eval(), which can only evaluate Python
+         # literals (lists, tuples, strings, numbers, dicts, booleans, None).
+         try:
+             parsed_history: Any = ast.literal_eval(history)
+
+             # Basic validation that the payload has the expected shape
+             if not isinstance(parsed_history, list):
+                 raise ValueError("History is not a list.")
+             # Further checks (e.g. that items are pairs of strings) could be added here
+
+             chat_history: List[Tuple[str, str]] = [(str(u), str(a)) for u, a in parsed_history]  # ensure elements are strings
+
+         except (ValueError, SyntaxError, TypeError) as e:
+             # The history string is not a valid literal or not in the expected structure
+             raise HTTPException(status_code=400, detail=f"Invalid history format: {e}")
+         # --- End history parsing ---
+
+         system_prompt = DEFAULT_SYSTEM_PROMPT  # could also be made configurable
          message = prompt
 
          prompt_text = get_prompt(message, chat_history, system_prompt)
@@ -69,23 +105,50 @@ async def generate_response(prompt: str = Form(...), history: str = Form(...)):
              do_sample=True,
              top_p=0.9,
              top_k=50,
-             temperature=0.1,
+             temperature=0.1,  # keep the temperature low for more predictable output
+             stop_sequences=[EOT_TOKEN],  # explicitly tell the API to stop at the end-of-turn token
+             # return_full_text=False  # may be worth experimenting with, depending on API behavior
          )
-
-         stream = client.generate_stream(prompt_text, **generate_kwargs)
-         output = ""
-         for response in stream:
-             if any([end_token in response.token.text for end_token in [EOS_STRING, EOT_STRING]]):
-                 break
-             else:
-                 output += response.token.text
-
-         # Translate the response to Arabic
-         translator = GoogleTranslator(source='auto', target='ar')
-         translated_output = translator.translate(output)
+
+         # Using generate() (non-streaming) keeps the post-processing simple.
+         # If streaming is needed, the cleanup below has to be adapted to the stream loop;
+         # generate() is used here because it makes trimming the output easier.
+         output_obj = client.generate(prompt_text, **generate_kwargs)
+         output = output_obj.generated_text  # the final generated text
+
+         # --- Post-process the output ---
+         # The model should generate only the assistant response after ASSISTANT_HEADER,
+         # but leading whitespace or unexpected tokens can still occur.
+         # Strip leading whitespace and the expected header if it was accidentally generated;
+         # crucially, this removes the ASSISTANT_HEADER that was part of the prompt structure.
+         expected_start = ASSISTANT_HEADER.strip()
+         if output.startswith(expected_start):
+             output = output[len(expected_start):].strip()
+         else:
+             # Fallback: just strip leading/trailing whitespace
+             output = output.strip()
+
+
+         # Remove the EOT token if it is still present at the end
+         if output.endswith(EOT_TOKEN):
+             output = output[:-len(EOT_TOKEN)].strip()
+         # --- End post-processing ---
+
+         # Make sure the output is not empty before translating
+         if not output:
+             # The model produced no output after cleaning; return a default message
+             translated_output = "Error: Model did not produce a response or response was filtered."
+         else:
+             # Translate the cleaned response to Arabic
+             translator = GoogleTranslator(source='auto', target='ar')
+             translated_output = translator.translate(output)
 
          return {"response": translated_output}
 
+     except HTTPException as he:
+         # Re-raise HTTPExceptions (such as the 400 for an invalid history)
+         raise he
      except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
-
+         # Log other errors for server-side debugging
+         print(f"An error occurred during generation or translation: {e}")
+         raise HTTPException(status_code=500, detail=f"An internal error occurred: {e}")
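
A minimal client-side sketch (not part of the commit) of how the updated /generate/ endpoint could be exercised. The base URL and port are assumptions (a local uvicorn run); the history form field must be a Python-literal list of (user, assistant) pairs so that ast.literal_eval on the server can parse it, and the endpoint returns the Arabic translation under the "response" key.

# Hypothetical client for the /generate/ endpoint; the base URL and port are assumptions.
import requests

history = [
    ("Hello, who are you?", "Hi! I am an assistant. How can I help you today?"),
]

resp = requests.post(
    "http://localhost:8000/generate/",  # assumed local server address
    data={
        "prompt": "Explain in one sentence what FastAPI Form parameters are.",
        "history": repr(history),  # Python-literal string, parseable by ast.literal_eval on the server
    },
    timeout=180,
)
resp.raise_for_status()
print(resp.json()["response"])  # the Arabic translation returned by the endpoint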