File size: 1,108 Bytes
8cf7512 3f71f43 f1df6cd 3f71f43 2895d63 3f71f43 8cf7512 3f71f43 e37d48b 5cd819c 3f71f43 2895d63 3f71f43 cc568d3 3f71f43 2895d63 3f71f43 2895d63 3f71f43 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from fastapi import FastAPI
from reranker import RankLLM, RankListwiseOSLLM, Result, RankingExecInfo
from pydantic import BaseModel
from typing import Optional, List, Tuple
# Build the listwise reranker once at import time; the /rerank handler below
# reuses this single module-level instance for every request.
# NOTE(review): the argument looks like a Hugging Face model id, so this call
# presumably loads model weights at startup — confirm against the reranker module.
reranker = RankListwiseOSLLM("Salesforce/SweRankLLM-small")
class RerankRequest(BaseModel):
    # Request body accepted by POST /rerank.

    # Text the candidate hits are ranked against.
    query: str
    # Candidate passages as (index, text) pairs; the handler sorts them by
    # index before handing them to the reranker.
    hits: List[Tuple[int, str]]
class RerankResponse(BaseModel):
    # Response schema with the same two fields as the request model.
    # NOTE(review): no route visible here uses this model, and /rerank returns
    # its list under the key "reranked_hits" rather than "hits" — confirm
    # whether this class is still needed or should be aligned.

    # Query string echoed back to the caller.
    query: str
    # (index, text) pairs.
    hits: List[Tuple[int, str]]
# Application object; the @app.get / @app.post decorators below register
# their handlers on it.
app = FastAPI()
@app.get("/")
def hello_world():
    # Root endpoint: always answers with the same fixed success payload.
    payload = {"msg": "Success"}
    return payload
@app.post("/rerank")
def rerank(request: RerankRequest):
    """Rerank the request's hits with the module-level listwise reranker.

    Args:
        request: Body carrying the query string and ``(index, text)`` hit
            pairs.

    Returns:
        A dict with the original ``query`` and ``reranked_hits``: a list of
        ``(position, text)`` tuples, numbered from 1 in the order the
        reranker returned them. An empty ``hits`` list yields an empty
        ``reranked_hits`` list without calling the reranker.
    """
    # Nothing to rank: answer directly instead of running the model pipeline
    # over a zero-width window.
    if not request.hits:
        return {"query": request.query, "reranked_hits": []}

    # Order by the caller-supplied index so the reranker sees the candidates
    # in their declared order, whatever order they arrived in.
    sorted_hits = sorted(request.hits, key=lambda hit: hit[0])
    result = Result(
        query=request.query,
        hits=[{"content": text} for _, text in sorted_hits],
    )

    # NOTE(review): the positional 0 and len(...) are presumably the start and
    # end of the rank window, i.e. "rerank every candidate" — confirm against
    # the reranker's permutation_pipeline signature.
    reranked_result = reranker.permutation_pipeline(
        result,
        0,
        len(sorted_hits),
        logging=True,
    )

    # Renumber from 1 so the position reflects the new rank, not the input index.
    reranked_hits = [
        (position, item["content"])
        for position, item in enumerate(reranked_result.hits, start=1)
    ]
    return {"query": request.query, "reranked_hits": reranked_hits}
|