File size: 2,337 Bytes
a3e9331
 
 
 
 
17e53e4
a3e9331
17e53e4
 
 
 
 
a3e9331
 
 
43764a8
a3e9331
 
 
 
 
 
 
43764a8
 
a3e9331
43764a8
a3e9331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17e53e4
 
 
 
 
a3e9331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
MCPyLate Server
A Model Context Protocol server that provides search functionality using PyLate.
"""

# import subprocess

# subprocess.run(
#     "pip install flash-attn --no-build-isolation",
#     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
#     shell=True,
# )
from typing import Any, Dict, List, Optional

from core import MCPyLate
from huggingface_hub import snapshot_download
from mcp.server.fastmcp import FastMCP


def register_tools(mcp: FastMCP, pylate: MCPyLate):
    """Register the search and document-lookup tools on the MCP server.

    Two tools are attached to ``mcp`` through its ``tool`` decorator; both
    are thin wrappers that delegate to ``pylate``.

    Args:
        mcp: Server instance the tools are registered on.
        pylate: Backend providing ``search`` and ``get_document``.
    """

    @mcp.tool(
        name="pylate_search_leetcode",
        description="Perform a multi-vector search on the leetcode index. Returns top‑k hits with docid, score, and snippet.",
    )
    def pylate_search_leetcode(
        query: str, k: int = 10, index_name: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Run a multi-vector search through the PyLate backend.

        Args:
            query: Search query string
            k: Number of results to return (default: 10)
            index_name: Accepted for interface compatibility but currently
                ignored; it is not forwarded to the backend.

        Returns:
            Whatever ``pylate.search(query, k)`` returns (annotated as a list
            of result dictionaries).
        """
        # NOTE(review): index_name is intentionally not passed on, because the
        # signature of MCPyLate.search is not visible from this module.
        # Forward it once the backend is confirmed to accept it.
        return pylate.search(query, k)

    @mcp.tool(
        name="get_document",
        # The backend is PyLate (MCPyLate), not Pyserini; the description
        # shown to clients must name the right index type.
        description="Retrieve a full document by its document ID from a PyLate index.",
    )
    def get_document(
        docid: str, index_name: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Retrieve a document by its ID through the PyLate backend.

        Args:
            docid: Document ID to retrieve
            index_name: Index name forwarded unchanged to the backend
                (``None`` by default).

        Returns:
            Whatever ``pylate.get_document(docid, index_name)`` returns
            (annotated as a document dictionary or None).
        """
        return pylate.get_document(docid, index_name)


def main():
    """Download the index data, build the MCP server, and serve it over stdio.

    Steps:
        1. Download the ``lightonai/leetcode_reasonmoderncolbert`` dataset
           repository into the local ``indexess/`` directory.
        2. Create the FastMCP server and the MCPyLate backend, then register
           the tools.
        3. Run the server on the stdio transport.

    Raises:
        Exception: Any error raised while building or running the server is
            reported on stderr and then re-raised unchanged. Errors from the
            download step propagate without being reported here.
    """
    import sys  # local import: only needed for the error report below

    # NOTE(review): "indexess/" is presumably the directory MCPyLate reads its
    # index from; the spelling is kept as-is. Confirm against core.py.
    snapshot_download(
        repo_id="lightonai/leetcode_reasonmoderncolbert",
        local_dir="indexess/",
        repo_type="dataset",
    )
    try:
        mcp = FastMCP("pylate-search-server")

        mcpylate = MCPyLate()
        register_tools(mcp, mcpylate)

        mcp.run(transport="stdio")

    except Exception as e:
        # On the stdio transport stdout carries the protocol messages, so
        # diagnostics must go to stderr to avoid corrupting that stream.
        print(e, file=sys.stderr)
        raise

# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()