import modal
import logging
app = modal.App("qwen-reranker-vllm")
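# Persistent Volumes: one for Hugging Face model weights, one for vLLM's cache,
# so cold starts don't re-download or re-compile.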
hf_cache_vol = modal.Volume.from_name("mcp-datascientist-model-weights-vol")
vllm_cache_vol = modal.Volume.from_name("vllm-cache")
MINUTES = 60  # seconds
vllm_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "vllm==0.8.5",
        "transformers",
        "torch",
        "fastapi[all]",
        "pydantic",
        "hf_transfer",  # required once HF_HUB_ENABLE_HF_TRANSFER is set
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # faster weight downloads from the Hub
)
with vllm_image.imports():
    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams
    from vllm.inputs.data import TokensPrompt
    import torch
    import math
@app.cls(
    image=vllm_image,
    gpu="A100-40GB",
    scaledown_window=15 * MINUTES,  # how long to stay up with no requests
    timeout=10 * MINUTES,
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,
        "/root/.cache/vllm": vllm_cache_vol,
    },
)
class Reranker:
    @modal.enter()
    def load_reranker(self):
        logging.info("in the rank function")
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-4B")
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token
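        # vLLM engine; prefix caching lets the shared system/query prefix reuse
        # KV cache across the documents in a batch.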
        self.model = LLM(
            model="Qwen/Qwen3-Reranker-4B",
            tensor_parallel_size=torch.cuda.device_count(),
            max_model_len=10000,
            enable_prefix_caching=True,
            gpu_memory_utilization=0.8
        )
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
        self.max_length = 8192
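        # Token ids for "yes" and "no"; the final score is the probability mass on "yes".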
        self.true_token = self.tokenizer("yes", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer("no", add_special_tokens=False).input_ids[0]
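        # Greedy decoding of exactly one token, restricted to the "yes"/"no" candidates.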
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=20,
            allowed_token_ids=[self.true_token, self.false_token],
        )
    def format_instruction(self, instruction, query, doc):
        return [
            {"role": "system", "content": "Judge whether the Table will be useful for writing a SQL query that answers the Query. Note that the answer can only be \"yes\" or \"no\"."},
            {"role": "user", "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}"}
        ]
    def process_inputs(self, pairs, instruction):
        messages = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
        messages = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=False, enable_thinking=False
        )
        # Truncate each prompt to max_length tokens, then append the fixed suffix.
        messages = [ele[: self.max_length] + self.suffix_tokens for ele in messages]
        return [TokensPrompt(prompt_token_ids=ele) for ele in messages]
    def compute_logits(self, messages):
        outputs = self.model.generate(messages, self.sampling_params, use_tqdm=False)
        scores = []
        for output in outputs:
            # Logprobs of the single generated token, restricted to the allowed ids.
            final_logits = output.outputs[0].logprobs[-1]
            true_logit = final_logits[self.true_token].logprob if self.true_token in final_logits else -10
            false_logit = final_logits[self.false_token].logprob if self.false_token in final_logits else -10
            # Two-way softmax over "yes"/"no" gives the relevance score.
            true_score = math.exp(true_logit)
            false_score = math.exp(false_logit)
            scores.append(true_score / (true_score + false_score))
        return scores
    @modal.method()
    def rerank(self, query, documents, task):
        # e.g. task = "Given a web search query, retrieve relevant passages that answer the query"
        pairs = [(query, doc) for doc in documents]
        inputs = self.process_inputs(pairs, task)
        scores = self.compute_logits(inputs)
        return [{"score": float(score), "content": doc} for score, doc in zip(scores, documents)]
@app.function(
    image=modal.Image.debian_slim(python_version="3.12").pip_install(
        "fastapi[standard]==0.115.4", "pydantic"
    )
)
@modal.asgi_app(label="rerank-endpoint")
def fastapi_app():
    from typing import List

    from fastapi import FastAPI
    from pydantic import BaseModel
    web_app = FastAPI()
    reranker = Reranker()
    class ScoringResult(BaseModel):
        score: float
        content: str
    class RankingRequest(BaseModel):
        task: str
        query: str
        documents: List[str]
    @web_app.post("/rank",response_model=List[ScoringResult])
    async def predict(payload: RankingRequest):
        logging.info("call the rank function")
        query = payload.query
        documents = payload.documents
        task = payload.task
        output_data = reranker.rerank.remote(query,documents,task)
        return JSONResponse(content=output_data)
    return web_app
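
# Example request once deployed (`modal deploy` this file). The URL shape
# `https://<workspace>--rerank-endpoint.modal.run` follows Modal's convention for
# the label above; substitute the URL that `modal deploy` actually prints.
#
#   curl -X POST "https://<workspace>--rerank-endpoint.modal.run/rank" \
#     -H "Content-Type: application/json" \
#     -d '{"task": "Retrieve tables useful for answering the question with SQL",
#          "query": "total sales per region in 2023",
#          "documents": ["CREATE TABLE sales (region TEXT, amount REAL, year INT)",
#                        "CREATE TABLE employees (id INT, name TEXT)"]}'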