|
from asyncio import sleep |
|
from typing import Optional |
|
from fastapi import FastAPI |
|
from fastapi.encoders import jsonable_encoder |
|
from fastapi.websockets import WebSocket, WebSocketDisconnect |
|
from fastapi.responses import HTMLResponse, JSONResponse |
|
from websockets import ConnectionClosed |
|
|
|
from accelerator import Accelerator |
|
from answerer import Answerer |
|
from mapper import Mapper |
|
|
|
# The Mapper (used only by /map) is optional: if loading it fails, keep the
# rest of the app running and bind `mapper` to None so the name always exists
# and callers can detect the missing model instead of hitting NameError.
try:
    mapper = Mapper("sentence-transformers/multi-qa-distilbert-cos-v1")
except Exception as e:
    mapper = None
    print(f"ERROR! cannot load Mapper model!\n{e}")
|
|
|
# Local RWKV model. /answer uses it directly when no accelerator is connected,
# and as the fallback when the accelerator returns nothing.
_answerer_options = {
    "model": "RWKV-5-World-3B-v2-20231118-ctx16k",
    "vocab": "rwkv_vocab_v20230424",
    "strategy": "cpu bf16",
    "ctx_limit": 16 * 1024,
}
answerer = Answerer(**_answerer_options)

# Holds the optional remote accelerator that connects through /accelerate.
accelerator = Accelerator()

app = FastAPI()
|
|
|
# Minimal single-page client for the /answer websocket endpoint.
HTML = """
<!DOCTYPE HTML>
<html>
<body>
<form action="" onsubmit="ask(event)">
<textarea id="prompt"></textarea>
<br>
<input type="submit" value="SEND" />
</form>
<p id="output"></p>
<script>
const promptBox = document.getElementById("prompt");
const output = document.getElementById("output");

// Derive the socket URL from the page's own origin instead of hard-coding
// one deployment host, so the page also works locally or on another host.
const wsUrl = (location.protocol === "https:" ? "wss://" : "ws://") + location.host + "/answer";

function ask(event) {
  // Stop the native form submit first; otherwise any failure path reloads the page.
  event.preventDefault();
  // The /answer endpoint handles one prompt per connection and then closes
  // the socket, so open a fresh connection for every question.
  const ws = new WebSocket(wsUrl);
  ws.onopen = () => ws.send(promptBox.value);
  ws.onmessage = (e) => answer(e.data);
  ws.onerror = () => answer("websocket is not connected!");
}

function answer(value) {
  // Model output is untrusted text: render it as text, never as HTML.
  output.textContent = value;
}
</script>
</body>
</html>
"""
|
|
|
@app.get("/")
def index():
    """Serve the bundled HTML demo page stored in ``HTML``."""
    return HTMLResponse(content=HTML)
|
|
|
@app.websocket("/accelerate")
async def accelerate(ws: WebSocket):
    """Hand *ws* to the accelerator and keep this handler alive while it is connected.

    This function used to be named ``answer``, the same as the ``/answer``
    handler defined later. Both routes therefore got the route name "answer",
    and ``app.url_path_for("answer")`` resolved to ``/accelerate``. The
    module-level name ``answer`` still refers to the ``/answer`` handler,
    because the later definition already overwrote this one.
    """
    await accelerator.connect(ws)
    # Poll every 10 s and return once the accelerator no longer reports itself
    # connected. Presumably returning earlier would let the server close the
    # socket. TODO confirm against Accelerator.
    while accelerator.connected():
        await sleep(10)
|
|
|
@app.post("/map")
def map(query: Optional[str], items: Optional[list[str]]):
    """Score *items* against *query* with the Mapper model and return them as JSON.

    Whatever ``mapper(query, items)`` returns is passed through
    ``jsonable_encoder``. If the Mapper model failed to load at import time
    (see the try/except around its construction), respond with HTTP 503
    instead of crashing with NameError/TypeError.

    NOTE: the function name shadows the builtin ``map`` at module level; it
    is kept for compatibility.
    """
    # When loading failed, `mapper` may be unbound or None, so look it up
    # defensively instead of referencing the name directly.
    model = globals().get("mapper")
    if model is None:
        return JSONResponse(
            {"detail": "Mapper model is not loaded"}, status_code=503
        )
    scores = model(query, items)
    return JSONResponse(jsonable_encoder(scores))
|
|
|
async def handle_answerer_local(ws: WebSocket, input: str):
    """Answer *input* with the local model and send the final result over *ws*.

    ``answerer(input, 128)`` is consumed as an async iterator, and only the
    last item it yields is sent. Presumably each item is the cumulative
    answer so far, and 128 is a length limit; TODO confirm against Answerer.

    If the iterator yields nothing, an empty string is sent. Previously this
    raised UnboundLocalError, because the loop variable was never bound.
    """
    last = ""
    async for chunk in answerer(input, 128):
        last = chunk
    await ws.send_text(last)
|
|
|
async def handle_answerer_accelerated(ws: WebSocket, input: str):
    """Try to answer *input* through the accelerator, falling back to the local model.

    The accelerator's result is sent as-is when it is truthy. Otherwise the
    local answerer handles the prompt.
    """
    remote_answer = await accelerator.accelerate(input)
    if not remote_answer:
        # Nothing usable came back from the accelerator: answer locally instead.
        await handle_answerer_local(ws, input)
        return
    await ws.send_text(remote_answer)
|
|
|
@app.websocket("/answer")
async def answer(ws: WebSocket):
    """Answer exactly one prompt per websocket connection, then close it.

    The first text message is the prompt. If an accelerator is connected, the
    prompt goes through it (with a local fallback); otherwise the local model
    answers directly. A client disconnect ends the handler quietly, without
    an explicit close.
    """
    await ws.accept()
    try:
        prompt = await ws.receive_text()
        handler = (
            handle_answerer_accelerated
            if accelerator.connected()
            else handle_answerer_local
        )
        await handler(ws, prompt)
    except (ConnectionClosed, WebSocketDisconnect):
        return
    else:
        await ws.close()
|
|