L3 / main.py
Luisgust's picture
Create main.py
1634927 verified
import asyncio
import json
import os

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
# from huggingface_hub import InferenceClient # Remove this line
from groq import Groq # Import the Groq client
# Application instance that the route and middleware registrations attach to.
app = FastAPI()

# Groq API client. The key is read from the GROQ_API_KEY environment variable;
# when the variable is unset, None is handed to the client constructor.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
# System prompt placed at the start of every conversation sent to the model.
# Adjacent string literals are joined with nothing in between, so each
# fragment must end with its own space (or newline) to keep sentences apart.
SYSTEM_MESSAGE = (
    "You are a helpful, respectful and honest assistant. Always answer as helpfully "
    "as possible, while being safe. Your answers should not include any harmful, "
    "unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure "
    "that your responses are socially unbiased and positive in nature.\n\nIf a question "
    "does not make any sense, or is not factually coherent, explain why instead of "
    "answering something not correct. If you don't know the answer to a question, please "
    "don't share false information. "
    "Always respond in the language of user prompt for each prompt ."
)
# Generation settings forwarded with every chat completion request.
MAX_TOKENS: int = 2000  # sent as max_tokens
TEMPERATURE: float = 0.7  # sent as temperature
TOP_P: float = 0.95  # sent as top_p

# Model identifier sent with each request.
# NOTE(review): confirm this id is still offered by Groq before deploying.
GROQ_MODEL_NAME: str = "llama3-8b-8192"
def respond(message, history: list[tuple[str, str]]):
    """Stream the model's reply to ``message`` as text fragments.

    The conversation sent to the model is the system prompt, followed by the
    non-empty user/assistant texts from ``history`` in order, followed by
    ``message`` as the final user turn.

    Args:
        message: The new user prompt.
        history: Earlier exchanges. Index 0 of each entry is the user text and
            index 1 is the assistant text; falsy texts are skipped.

    Yields:
        Every non-None content delta of the streamed completion, in order.
    """
    conversation = [{"role": "system", "content": SYSTEM_MESSAGE}]
    for exchange in history:
        # Each exchange contributes up to two messages: user first, then assistant.
        for role, position in (("user", 0), ("assistant", 1)):
            text = exchange[position]
            if text:
                conversation.append({"role": role, "content": text})
    conversation.append({"role": "user", "content": message})

    # stream=True makes the client return an iterable of incremental chunks.
    stream = client.chat.completions.create(
        model=GROQ_MODEL_NAME,
        messages=conversation,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_tokens=MAX_TOKENS,
        stream=True,
    )

    for chunk in stream:
        if not chunk.choices:
            continue
        piece = chunk.choices[0].delta.content
        if piece is not None:
            yield piece
from fastapi.middleware.cors import CORSMiddleware
# CORS policy: only the single listed origin is allowed; credentials are
# permitted, and every HTTP method and request header is accepted.
_CORS_SETTINGS = {
    "allow_origins": ["https://artixiban-ll3.static.hf.space"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
@app.post("/generate/")
async def generate(request: Request):
    """Return the model's complete reply to a form-encoded chat request.

    Form fields:
        prompt: The user's message (required, text).
        history: Optional JSON array of earlier exchanges. Each exchange is an
            array whose first two items are the user text and the assistant
            text. A missing or blank field means no earlier conversation.

    Returns:
        JSONResponse with the body ``{"response": <full reply text>}``.

    Raises:
        HTTPException: 403 when the Origin header is not the allowed origin;
            400 when the prompt is missing or the history is malformed.
    """
    allowed_origin = "https://artixiban-ll3.static.hf.space"
    if request.headers.get("origin") != allowed_origin:
        raise HTTPException(status_code=403, detail="Origin not allowed")

    form = await request.form()

    prompt = form.get("prompt")
    # A form field can also be an uploaded file; only non-empty text is a prompt.
    if not isinstance(prompt, str) or not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")

    raw_history = form.get("history") or "[]"
    try:
        history = json.loads(raw_history)
    except (TypeError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError; TypeError covers a field
        # that is not text at all. Both are client errors, not server faults.
        raise HTTPException(
            status_code=400, detail="History must be valid JSON"
        ) from exc

    # respond() indexes [0] and [1] on every entry, so reject any other shape
    # here instead of letting it fail later as a 500.
    well_formed = isinstance(history, list) and all(
        isinstance(exchange, list) and len(exchange) >= 2 for exchange in history
    )
    if not well_formed:
        raise HTTPException(
            status_code=400,
            detail="History must be a list of [user, assistant] pairs",
        )

    # respond() drives a blocking, synchronous streaming request. Draining it
    # on a worker thread keeps the event loop free to serve other requests.
    final_response = await asyncio.to_thread(
        lambda: "".join(respond(prompt, history))
    )
    return JSONResponse(content={"response": final_response})