mrinjera commited on
Commit
2895d63
·
verified ·
1 Parent(s): 08d77f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -11,6 +11,7 @@ class RerankRequest(BaseModel):
11
  hits: List[Tuple[int, str]]
12
 
13
  class RerankResponse(BaseModel):
 
14
  hits: List[Tuple[int, str]]
15
 
16
  app = FastAPI()
@@ -21,13 +22,13 @@ def hello_world():
21
 
22
 
23
  @app.get("/rerank")
24
- def rerank(request: RerankRequest):
25
  hits = request.hits
26
  sorted_hits = sorted(hits, key=lambda x: x[0]) # sort hits again for safety
27
 
28
  result = Result(
29
- query=request["query"],
30
- hits = [{"content": hit} for hit in sorted_hits]
31
  )
32
 
33
  reranked_result = reranker.permutation_pipeline(
@@ -37,8 +38,11 @@ def rerank(request: RerankRequest):
37
  logging=True
38
  )
39
 
40
- response = [(i, item["content"]) for i, item in enumerate(reranked_result.hits)]
41
 
42
- return {"reranked": response}
 
 
 
43
 
44
 
 
11
  hits: List[Tuple[int, str]]
12
 
13
  class RerankResponse(BaseModel):
14
+ query: str
15
  hits: List[Tuple[int, str]]
16
 
17
  app = FastAPI()
 
22
 
23
 
24
  @app.get("/rerank")
25
+ def rerank(request: RerankRequest) -> RerankResponse:
26
  hits = request.hits
27
  sorted_hits = sorted(hits, key=lambda x: x[0]) # sort hits again for safety
28
 
29
  result = Result(
30
+ query=request.query,
31
+ hits = [{"content": hit[1]} for hit in sorted_hits]
32
  )
33
 
34
  reranked_result = reranker.permutation_pipeline(
 
38
  logging=True
39
  )
40
 
41
+ reranked_hits = [(i + 1, item["content"]) for i, item in enumerate(reranked_result.hits)]
42
 
43
+ return {
44
+ "query": request.query,
45
+ "reranked_hits": reranked_hits
46
+ }
47
 
48