File size: 1,108 Bytes
8cf7512
3f71f43
 
 
 
 
f1df6cd
3f71f43
 
 
 
 
 
2895d63
3f71f43
8cf7512
 
 
 
 
3f71f43
 
 
e37d48b
5cd819c
3f71f43
 
 
 
2895d63
 
3f71f43
 
 
 
cc568d3
3f71f43
 
 
 
2895d63
3f71f43
2895d63
 
 
 
3f71f43
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from fastapi import FastAPI
from reranker import RankLLM, RankListwiseOSLLM, Result, RankingExecInfo
from pydantic import BaseModel
from typing import Optional, List, Tuple

# Checkpoint served by this app. The reranker is constructed once, at module
# import time; the /rerank endpoint reuses this single module-level instance
# for every request.
_RERANKER_CHECKPOINT = "Salesforce/SweRankLLM-small"
reranker = RankListwiseOSLLM(_RERANKER_CHECKPOINT)

class RerankRequest(BaseModel):
    """Request body accepted by the POST /rerank endpoint."""

    # Query text the hits are reranked against.
    query: str
    # Candidate hits as (int, str) pairs. /rerank sorts them by the int element
    # and passes the str element to the reranker as the passage content
    # (presumably the int is the hit's current rank — confirm with callers).
    hits: List[Tuple[int, str]]

class RerankResponse(BaseModel):
    """Response schema describing a query with its (rank, content) hits."""

    # NOTE(review): the /rerank endpoint does not declare this as its
    # response_model and returns the key "reranked_hits" rather than "hits";
    # confirm which shape clients rely on before wiring this model in.

    # Query text echoed back to the caller.
    query: str
    # Hits as (int, str) pairs — (rank, content) in the endpoint's output.
    hits: List[Tuple[int, str]]

# FastAPI application instance; the @app.get / @app.post decorators register
# their endpoints on it.
app = FastAPI()

@app.get("/")
def hello_world():
    """Return a fixed success payload from the root path.

    The endpoint does no work beyond building the constant response, so it
    can serve as a simple reachability check.
    """
    status_payload = {"msg": "Success"}
    return status_payload


@app.post("/rerank")
def rerank(request: RerankRequest):
    """Rerank the request's hits with the module-level listwise reranker.

    The incoming hits are (int, str) pairs. They are ordered by their int
    element (presumably the caller's current rank), and their contents are
    passed, in that order, to ``reranker.permutation_pipeline`` over the
    full window ``[0, len(hits))``.

    Args:
        request: the query and the (rank, content) hits to rerank.

    Returns:
        A dict with the original ``query`` and ``reranked_hits``. The latter
        holds the hit contents in the reranker's output order, each paired
        with a fresh 1-based rank. An empty ``hits`` list yields an empty
        ``reranked_hits`` without invoking the reranker.
    """
    # NOTE(review): the returned key "reranked_hits" differs from the
    # RerankResponse model's "hits" field. It is kept as-is because clients
    # may depend on the current response shape.

    # Sort by the caller-supplied rank in case the list arrives out of order.
    sorted_hits = sorted(request.hits, key=lambda hit: hit[0])

    # Nothing to rerank. Avoid calling permutation_pipeline with an empty
    # window, which would invoke the model for no candidates.
    if not sorted_hits:
        return {"query": request.query, "reranked_hits": []}

    result = Result(
        query=request.query,
        hits=[{"content": content} for _, content in sorted_hits],
    )

    # Rerank the whole candidate list: window [0, len(sorted_hits)).
    reranked_result = reranker.permutation_pipeline(
        result,
        0,
        len(sorted_hits),
        logging=True,
    )

    # Renumber sequentially so each rank reflects the new order, not the
    # rank values the caller sent in.
    reranked_hits = [
        (position, item["content"])
        for position, item in enumerate(reranked_result.hits, start=1)
    ]

    return {
        "query": request.query,
        "reranked_hits": reranked_hits,
    }