App / main.py
Yjhhh's picture
Create main.py
bd132d0 verified
raw
history blame
2.11 kB
import json
import os
import uuid
from typing import AsyncGenerator, NoReturn

import google.generativeai as genai
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
# Load environment variables (python-dotenv, e.g. from a local .env file).
# This must run BEFORE the os.getenv call below so GOOGLE_API_KEY can come
# from that file.
load_dotenv()
# os.getenv returns None when GOOGLE_API_KEY is unset, in which case
# genai.configure receives api_key=None.
# NOTE(review): nothing here fails fast on a missing key; consider checking
# it at startup so the error is not deferred to the first request.
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
# Module-level model object used by the request helpers in this file.
# NOTE(review): "gemini-pro" may be deprecated upstream; confirm this model
# name is still served.
model = genai.GenerativeModel("gemini-pro")
# FastAPI application that the route decorators in this module register with.
app = FastAPI()
# Chat prompt template: ``{message}`` is filled in with the user's text via
# ``str.format``. The leading and trailing newlines are part of the value.
PROMPT = (
    "\n"
    "You are a helpful assistant, skilled in explaining complex concepts in simple terms.\n"  # noqa: E501
    "{message}\n"
)

# Image prompt template: ``{description}`` is filled in via ``str.format``.
IMAGE_PROMPT = (
    "\n"
    "Generate an image based on the following description:\n"
    "{description}\n"
)
async def get_ai_response(message: str) -> AsyncGenerator[str, None]:
    """Stream the model's answer to *message* as a series of JSON strings.

    The user's text is substituted into ``PROMPT`` and sent with streaming
    enabled. Each text part that arrives produces one JSON object of the
    form ``{"id": <uuid4 string>, "text": <all text received so far>}``.
    The ``id`` is fixed for the whole answer. ``text`` is cumulative rather
    than just the newest fragment.
    """
    stream = await model.generate_content_async(
        PROMPT.format(message=message), stream=True
    )
    reply_id = str(uuid.uuid4())
    so_far = ""
    async for chunk in stream:
        # Skip chunks that have no candidates.
        if not chunk.candidates:
            continue
        # Only the first candidate's parts are used.
        for part in chunk.candidates[0].content.parts:
            so_far = so_far + part.text
            yield json.dumps({"id": reply_id, "text": so_far})
async def get_ai_image(description: str) -> str:
    """Ask the configured model for an image matching *description*.

    Returns a JSON-encoded string: ``{"image_url": ...}`` when an image is
    produced, otherwise ``{"error": ...}`` describing why not.
    """
    # NOTE(review): google.generativeai.GenerativeModel does not appear to
    # expose ``generate_image_async`` (text models such as "gemini-pro" do
    # not produce images). Calling it unconditionally raised AttributeError
    # and turned every request into a 500. Probe for it and fall back to the
    # same error payload shape instead.
    generate_image = getattr(model, "generate_image_async", None)
    if generate_image is None:
        return json.dumps(
            {"error": "Image generation is not supported by this model"}
        )

    response = await generate_image(
        IMAGE_PROMPT.format(description=description)
    )
    # Read ``images`` defensively: the response shape is not guaranteed.
    images = getattr(response, "images", None)
    if images:
        # Use the first generated image.
        return json.dumps({"image_url": images[0].url})
    return json.dumps({"error": "No image generated"})
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve streamed model replies over a websocket at ``/ws``.

    Each text frame from the client is treated as one user message. Every
    JSON update produced by ``get_ai_response`` for that message is sent
    back as a text frame. The session ends when the client disconnects.
    """
    await websocket.accept()
    try:
        while True:
            message = await websocket.receive_text()
            async for text in get_ai_response(message):
                await websocket.send_text(text)
    except WebSocketDisconnect:
        # The client closing the socket is the normal end of a session.
        # Return cleanly instead of letting the exception escape the
        # endpoint as an unhandled error.
        return
@app.post("/generate-image/")
async def generate_image_endpoint(description: str):
    """Handle ``POST /generate-image/`` for the given *description*.

    FastAPI treats a bare ``str`` parameter like this one as a query
    parameter. ``get_ai_image`` returns a JSON string, which is decoded here
    so FastAPI serialises the resulting dict as the response body.
    """
    raw_json = await get_ai_image(description)
    payload = json.loads(raw_json)
    return payload
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces,
    # port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)