Spaces:
Build error
Build error
Daniel Marques
commited on
Commit
·
5d66516
1
Parent(s):
57fab83
feat: add broadcast
Browse files- main.py +46 -20
- redisPubSubManger.py +69 -0
- requirements.txt +1 -1
- run.sh +1 -2
- test.txt +0 -1
- webSocketManger.py +70 -0
main.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
-
from typing import Any, Dict, Union
|
| 2 |
-
|
| 3 |
import os
|
| 4 |
import glob
|
| 5 |
import shutil
|
| 6 |
import subprocess
|
| 7 |
import torch
|
|
|
|
| 8 |
|
| 9 |
from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
|
| 10 |
from fastapi.staticfiles import StaticFiles
|
|
|
|
| 11 |
|
| 12 |
from pydantic import BaseModel
|
| 13 |
|
|
@@ -55,6 +55,8 @@ QA = RetrievalQA.from_chain_type(
|
|
| 55 |
},
|
| 56 |
)
|
| 57 |
|
|
|
|
|
|
|
| 58 |
app = FastAPI(title="homepage-app")
|
| 59 |
api_app = FastAPI(title="api app")
|
| 60 |
|
|
@@ -162,8 +164,6 @@ def predict(data: Predict):
|
|
| 162 |
except Exception as e:
|
| 163 |
raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}")
|
| 164 |
|
| 165 |
-
|
| 166 |
-
|
| 167 |
@api_app.post("/save_document/")
|
| 168 |
async def create_upload_file(file: UploadFile):
|
| 169 |
# Get the file size (in bytes)
|
|
@@ -204,31 +204,57 @@ async def create_upload_file(file: UploadFile):
|
|
| 204 |
|
| 205 |
return {"filename": file.filename}
|
| 206 |
|
| 207 |
-
@api_app.websocket("/ws/{
|
| 208 |
-
async def websocket_endpoint(websocket: WebSocket,
|
| 209 |
global QA
|
| 210 |
|
| 211 |
-
await
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
try:
|
| 214 |
while True:
|
| 215 |
-
|
| 216 |
-
response = QA(inputs=user_prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True)
|
| 217 |
-
answer, docs = response["result"], response["source_documents"]
|
| 218 |
|
| 219 |
-
|
| 220 |
-
"
|
| 221 |
-
"
|
|
|
|
| 222 |
}
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
except WebSocketDisconnect:
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
except RuntimeError as error:
|
| 234 |
print(error)
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import glob
|
| 3 |
import shutil
|
| 4 |
import subprocess
|
| 5 |
import torch
|
| 6 |
+
import json
|
| 7 |
|
| 8 |
from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
|
| 9 |
from fastapi.staticfiles import StaticFiles
|
| 10 |
+
from websocket.socketManager import WebSocketManager
|
| 11 |
|
| 12 |
from pydantic import BaseModel
|
| 13 |
|
|
|
|
| 55 |
},
|
| 56 |
)
|
| 57 |
|
| 58 |
+
socket_manager = WebSocketManager()
|
| 59 |
+
|
| 60 |
app = FastAPI(title="homepage-app")
|
| 61 |
api_app = FastAPI(title="api app")
|
| 62 |
|
|
|
|
| 164 |
except Exception as e:
|
| 165 |
raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}")
|
| 166 |
|
|
|
|
|
|
|
| 167 |
@api_app.post("/save_document/")
|
| 168 |
async def create_upload_file(file: UploadFile):
|
| 169 |
# Get the file size (in bytes)
|
|
|
|
| 204 |
|
| 205 |
return {"filename": file.filename}
|
| 206 |
|
| 207 |
+
@api_app.websocket("/ws/{room_id}/{user_id}")
|
| 208 |
+
async def websocket_endpoint(websocket: WebSocket, room_id: str, user_id: int):
|
| 209 |
global QA
|
| 210 |
|
| 211 |
+
await socket_manager.add_user_to_room(room_id, websocket)
|
| 212 |
+
|
| 213 |
+
message = {
|
| 214 |
+
"user_id": user_id,
|
| 215 |
+
"room_id": room_id,
|
| 216 |
+
"message": f"User {user_id} connected to room - {room_id}"
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
await socket_manager.broadcast_to_room(room_id, json.dumps(message))
|
| 220 |
|
| 221 |
try:
|
| 222 |
while True:
|
| 223 |
+
data = await websocket.receive_text()
|
|
|
|
|
|
|
| 224 |
|
| 225 |
+
message = {
|
| 226 |
+
"user_id": user_id,
|
| 227 |
+
"room_id": room_id,
|
| 228 |
+
"message": data
|
| 229 |
}
|
| 230 |
|
| 231 |
+
await socket_manager.broadcast_to_room(room_id, json.dumps(message))
|
| 232 |
+
|
| 233 |
+
# user_prompt = await websocket.receive_text()
|
| 234 |
+
# response = QA(inputs=user_prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True)
|
| 235 |
+
# answer, docs = response["result"], response["source_documents"]
|
| 236 |
+
|
| 237 |
+
# prompt_response_dict = {
|
| 238 |
+
# "Prompt": user_prompt,
|
| 239 |
+
# "Answer": answer,
|
| 240 |
+
# }
|
| 241 |
+
|
| 242 |
+
# prompt_response_dict["Sources"] = []
|
| 243 |
+
# for document in docs:
|
| 244 |
+
# prompt_response_dict["Sources"].append(
|
| 245 |
+
# (os.path.basename(str(document.metadata["source"])), str(document.page_content))
|
| 246 |
+
# )
|
| 247 |
+
# await websocket.send_json(prompt_response_dict)
|
| 248 |
|
| 249 |
except WebSocketDisconnect:
|
| 250 |
+
await socket_manager.remove_user_from_room(room_id, websocket)
|
| 251 |
+
|
| 252 |
+
message = {
|
| 253 |
+
"user_id": user_id,
|
| 254 |
+
"room_id": room_id,
|
| 255 |
+
"message": f"User {user_id} disconnected from room - {room_id}"
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
await socket_manager.broadcast_to_room(room_id, json.dumps(message))
|
| 259 |
except RuntimeError as error:
|
| 260 |
print(error)
|
redisPubSubManger.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import redis.asyncio as aioredis
|
| 3 |
+
import json
|
| 4 |
+
from fastapi import WebSocket
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RedisPubSubManager:
|
| 8 |
+
"""
|
| 9 |
+
Initializes the RedisPubSubManager.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
host (str): Redis server host.
|
| 13 |
+
port (int): Redis server port.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, host='localhost', port=6379):
|
| 17 |
+
self.redis_host = host
|
| 18 |
+
self.redis_port = port
|
| 19 |
+
self.pubsub = None
|
| 20 |
+
|
| 21 |
+
async def _get_redis_connection(self) -> aioredis.Redis:
|
| 22 |
+
"""
|
| 23 |
+
Establishes a connection to Redis.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
aioredis.Redis: Redis connection object.
|
| 27 |
+
"""
|
| 28 |
+
return aioredis.Redis(host=self.redis_host,
|
| 29 |
+
port=self.redis_port,
|
| 30 |
+
auto_close_connection_pool=False)
|
| 31 |
+
|
| 32 |
+
async def connect(self) -> None:
|
| 33 |
+
"""
|
| 34 |
+
Connects to the Redis server and initializes the pubsub client.
|
| 35 |
+
"""
|
| 36 |
+
self.redis_connection = await self._get_redis_connection()
|
| 37 |
+
self.pubsub = self.redis_connection.pubsub()
|
| 38 |
+
|
| 39 |
+
async def _publish(self, room_id: str, message: str) -> None:
|
| 40 |
+
"""
|
| 41 |
+
Publishes a message to a specific Redis channel.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
room_id (str): Channel or room ID.
|
| 45 |
+
message (str): Message to be published.
|
| 46 |
+
"""
|
| 47 |
+
await self.redis_connection.publish(room_id, message)
|
| 48 |
+
|
| 49 |
+
async def subscribe(self, room_id: str) -> aioredis.Redis:
|
| 50 |
+
"""
|
| 51 |
+
Subscribes to a Redis channel.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
room_id (str): Channel or room ID to subscribe to.
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
aioredis.ChannelSubscribe: PubSub object for the subscribed channel.
|
| 58 |
+
"""
|
| 59 |
+
await self.pubsub.subscribe(room_id)
|
| 60 |
+
return self.pubsub
|
| 61 |
+
|
| 62 |
+
async def unsubscribe(self, room_id: str) -> None:
|
| 63 |
+
"""
|
| 64 |
+
Unsubscribes from a Redis channel.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
room_id (str): Channel or room ID to unsubscribe from.
|
| 68 |
+
"""
|
| 69 |
+
await self.pubsub.unsubscribe(room_id)
|
requirements.txt
CHANGED
|
@@ -29,7 +29,7 @@ uvicorn
|
|
| 29 |
fastapi
|
| 30 |
websockets
|
| 31 |
pydantic
|
| 32 |
-
|
| 33 |
|
| 34 |
# Streamlit related
|
| 35 |
streamlit
|
|
|
|
| 29 |
fastapi
|
| 30 |
websockets
|
| 31 |
pydantic
|
| 32 |
+
aioredis
|
| 33 |
|
| 34 |
# Streamlit related
|
| 35 |
streamlit
|
run.sh
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
# Redis Support uncomment this lines
|
| 2 |
-
|
| 3 |
-
# nohup redis-server &
|
| 4 |
|
| 5 |
uvicorn "main:app" --port 7860 --host 0.0.0.0
|
|
|
|
| 1 |
# Redis Support uncomment this lines
|
| 2 |
+
nohup redis-server &
|
|
|
|
| 3 |
|
| 4 |
uvicorn "main:app" --port 7860 --host 0.0.0.0
|
test.txt
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
dkdaniz is an avatar of instagram, create by daniel marques
|
|
|
|
|
|
webSocketManger.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class WebSocketManager:
|
| 2 |
+
def __init__(self):
|
| 3 |
+
"""
|
| 4 |
+
Initializes the WebSocketManager.
|
| 5 |
+
|
| 6 |
+
Attributes:
|
| 7 |
+
rooms (dict): A dictionary to store WebSocket connections in different rooms.
|
| 8 |
+
pubsub_client (RedisPubSubManager): An instance of the RedisPubSubManager class for pub-sub functionality.
|
| 9 |
+
"""
|
| 10 |
+
self.rooms: dict = {}
|
| 11 |
+
self.pubsub_client = RedisPubSubManager()
|
| 12 |
+
|
| 13 |
+
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
|
| 14 |
+
"""
|
| 15 |
+
Adds a user's WebSocket connection to a room.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
room_id (str): Room ID or channel name.
|
| 19 |
+
websocket (WebSocket): WebSocket connection object.
|
| 20 |
+
"""
|
| 21 |
+
await websocket.accept()
|
| 22 |
+
|
| 23 |
+
if room_id in self.rooms:
|
| 24 |
+
self.rooms[room_id].append(websocket)
|
| 25 |
+
else:
|
| 26 |
+
self.rooms[room_id] = [websocket]
|
| 27 |
+
|
| 28 |
+
await self.pubsub_client.connect()
|
| 29 |
+
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
|
| 30 |
+
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
|
| 31 |
+
|
| 32 |
+
async def broadcast_to_room(self, room_id: str, message: str) -> None:
|
| 33 |
+
"""
|
| 34 |
+
Broadcasts a message to all connected WebSockets in a room.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
room_id (str): Room ID or channel name.
|
| 38 |
+
message (str): Message to be broadcasted.
|
| 39 |
+
"""
|
| 40 |
+
await self.pubsub_client._publish(room_id, message)
|
| 41 |
+
|
| 42 |
+
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
|
| 43 |
+
"""
|
| 44 |
+
Removes a user's WebSocket connection from a room.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
room_id (str): Room ID or channel name.
|
| 48 |
+
websocket (WebSocket): WebSocket connection object.
|
| 49 |
+
"""
|
| 50 |
+
self.rooms[room_id].remove(websocket)
|
| 51 |
+
|
| 52 |
+
if len(self.rooms[room_id]) == 0:
|
| 53 |
+
del self.rooms[room_id]
|
| 54 |
+
await self.pubsub_client.unsubscribe(room_id)
|
| 55 |
+
|
| 56 |
+
async def _pubsub_data_reader(self, pubsub_subscriber):
|
| 57 |
+
"""
|
| 58 |
+
Reads and broadcasts messages received from Redis PubSub.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
pubsub_subscriber (aioredis.ChannelSubscribe): PubSub object for the subscribed channel.
|
| 62 |
+
"""
|
| 63 |
+
while True:
|
| 64 |
+
message = await pubsub_subscriber.get_message(ignore_subscribe_messages=True)
|
| 65 |
+
if message is not None:
|
| 66 |
+
room_id = message['channel'].decode('utf-8')
|
| 67 |
+
all_sockets = self.rooms[room_id]
|
| 68 |
+
for socket in all_sockets:
|
| 69 |
+
data = message['data'].decode('utf-8')
|
| 70 |
+
await socket.send_text(data)
|