Yjhhh commited on
Commit
bd132d0
1 Parent(s): 61e4bb6

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +83 -0
main.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import uuid
4
+ from typing import AsyncGenerator, NoReturn
5
+
6
+ import google.generativeai as genai
7
+ import uvicorn
8
+ from dotenv import load_dotenv
9
+ from fastapi import FastAPI, WebSocket
10
+
11
+ load_dotenv()
12
+
13
+ genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
14
+ model = genai.GenerativeModel("gemini-pro")
15
+
16
+ app = FastAPI()
17
+
18
+ PROMPT = """
19
+ You are a helpful assistant, skilled in explaining complex concepts in simple terms.
20
+
21
+ {message}
22
+ """ # noqa: E501
23
+
24
+ IMAGE_PROMPT = """
25
+ Generate an image based on the following description:
26
+
27
+ {description}
28
+ """ # noqa: E501
29
+
30
+ async def get_ai_response(message: str) -> AsyncGenerator[str, None]:
31
+ """
32
+ Gemini Response
33
+ """
34
+ response = await model.generate_content_async(
35
+ PROMPT.format(message=message), stream=True
36
+ )
37
+
38
+ msg_id = str(uuid.uuid4())
39
+ all_text = ""
40
+ async for chunk in response:
41
+ if chunk.candidates:
42
+ for part in chunk.candidates[0].content.parts:
43
+ all_text += part.text
44
+ yield json.dumps({"id": msg_id, "text": all_text})
45
+
46
+ async def get_ai_image(description: str) -> str:
47
+ """
48
+ Gemini Image Generation
49
+ """
50
+ response = await model.generate_image_async(
51
+ IMAGE_PROMPT.format(description=description)
52
+ )
53
+
54
+ if response.images:
55
+ # Assuming we take the first generated image
56
+ return json.dumps({"image_url": response.images[0].url})
57
+ return json.dumps({"error": "No image generated"})
58
+
59
+ @app.websocket("/ws")
60
+ async def websocket_endpoint(websocket: WebSocket) -> NoReturn:
61
+ """
62
+ Websocket for AI responses
63
+ """
64
+ await websocket.accept()
65
+ while True:
66
+ message = await websocket.receive_text()
67
+ async for text in get_ai_response(message):
68
+ await websocket.send_text(text)
69
+
70
+ @app.post("/generate-image/")
71
+ async def generate_image_endpoint(description: str):
72
+ """
73
+ Endpoint for AI image generation
74
+ """
75
+ image_url = await get_ai_image(description)
76
+ return json.loads(image_url)
77
+
78
+ if __name__ == "__main__":
79
+ uvicorn.run(
80
+ app,
81
+ host="0.0.0.0",
82
+ port=7860
83
+ )