# mps-v2 / mps-api.py
# huynhdoo's picture
# Upload folder using huggingface_hub
# a71293a verified
from modal import Image, App, Secret, web_endpoint, Volume, enter, method, build
from typing import Dict
import sys
# Container image definition: Debian slim with Python 3.12 and the retrieval
# stack installed via pip. pysqlite3-binary is installed so that it can stand
# in for the stdlib sqlite3 module (see the sys.modules swap in this file).
model_image = Image.debian_slim(python_version="3.12").pip_install(
    "chromadb",
    "sentence-transformers",
    "pysqlite3-binary",
)
# Utilities
# Imports of packages that are installed in the container image (see the
# pip_install list of model_image) rather than necessarily on the local machine.
# NOTE(review): the source lost its indentation; all four statements below are
# assumed to belong inside this `with` block -- confirm against the original.
with model_image.imports():
    import os
    import numpy as np
    # Load pysqlite3 so that it is registered in sys.modules under its own name.
    __import__("pysqlite3")
    # Hot-swap the SQLite module: any later `import sqlite3` now resolves to
    # pysqlite3. This must run before chromadb is imported (presumably because
    # chromadb needs a newer SQLite than the base image ships -- TODO confirm).
    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") # Hotswap SQLlite version
# Application initialization
# Modal application; model_image is passed as the app-level image.
app = App("mps-api", image=model_image)

# Named volume "mps", looked up without creating it when missing, and the
# path under which classes mount it.
vol = Volume.from_name("mps", create_if_missing=False)
data_path = "/data"
############
# MAIN CLASS
############
@app.cls(
    timeout=30*60,
    volumes={data_path: vol},
)
class VECTORDB:
    """Semantic search over the "MPS" Chroma collection stored on the volume."""

    @enter()
    @build()
    def init(self):
        """Load the sentence encoder, then open the persisted collection."""
        import chromadb
        from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

        # Encoder used by Chroma to embed the query texts
        model_name = "Lajavaness/sentence-camembert-large"
        self.embedding_function = SentenceTransformerEmbeddingFunction(model_name=model_name)
        print(f"Embedding model loaded: {model_name}")

        # Vector database persisted under the mounted volume
        client = chromadb.PersistentClient(path=f"{data_path}/db")
        self.chroma_collection = client.get_collection(
            name="MPS",
            embedding_function=self.embedding_function,
        )
        print(f"{self.chroma_collection.count()} documents loaded.")

    @method()
    def search(self, queries, origins, n_results=10):
        """Query the collection, keeping only entries whose "origin" is in *origins*.

        Args:
            queries: Query text(s), passed to Chroma as ``query_texts``.
            origins: Allowed values for the "origin" metadata field.
            n_results: Number of results requested from Chroma.

        Returns:
            A ``(documents, metadatas, distances)`` tuple taken from the
            Chroma query result.
        """
        fields = ['documents', 'metadatas', 'distances']
        hits = self.chroma_collection.query(
            query_texts=queries,
            n_results=n_results,
            where={"origin": {"$in": origins}},
            include=fields,
        )
        return tuple(hits[field] for field in fields)
@app.cls(timeout=30*60)
class RANKING:
    """Re-rank candidate documents against a query with a cross-encoder."""

    @enter()
    @build()
    def init(self):
        """Load the cross-encoder model."""
        # Load crossencoder
        from sentence_transformers import CrossEncoder

        model_name = "Lajavaness/CrossEncoder-camembert-large"
        self.cross_encoder = CrossEncoder(model_name)
        print(f"Cross encoder model loaded: {model_name}")

    @method()
    def rank(self, query, documents):
        """Return document indices ordered from highest to lowest score.

        Args:
            query: Query text paired with every document.
            documents: Candidate document texts.

        Returns:
            A list of indices into *documents*, best match first. An empty
            (or missing) *documents* yields an empty list.
        """
        # Nothing to score: answer directly instead of handing the model an
        # empty batch, which CrossEncoder.predict is not guaranteed to accept.
        if not documents:
            return []
        pairs = [[query, doc] for doc in documents]
        scores = self.cross_encoder.predict(pairs)
        # argsort is ascending, so reverse it to put the best score first.
        return np.argsort(scores)[::-1].tolist()
###########
# ENDPOINTS
###########
@app.function(timeout=30*60)
@web_endpoint(method="POST")
def retrieve(query: Dict):
    """POST endpoint: search the vector database.

    Args:
        query: Request body with keys "query" and "origins", plus an
            optional "n_results".

    Returns:
        A dict with the "documents", "metadatas" and "distances" returned by
        ``VECTORDB.search``.

    Raises:
        KeyError: If "query" or "origins" is missing from the body.
    """
    # Log query
    print(f"Retrieve query: {query}...")

    # Forward n_results only when the caller supplied it, so that the default
    # declared by VECTORDB.search applies otherwise (a missing key used to
    # raise KeyError here).
    search_args = [query['query'], query['origins']]
    n_results = query.get('n_results')
    if n_results is not None:
        search_args.append(n_results)

    # Searching documents
    documents, metadatas, distances = VECTORDB().search.remote(*search_args)
    return {"documents" : documents, "metadatas" : metadatas, "distances" : distances}
@app.function(timeout=30*60)
@web_endpoint(method="POST")
def rank(query: Dict):
    """POST endpoint: rank documents against a query.

    Reads "query" and "documents" from the request body and returns
    ``{"ranking": ...}`` with the value produced by ``RANKING.rank``.
    """
    # Log the incoming request
    print(f"Rank query: {query}...")

    # Delegate the scoring to the remote ranking class
    ranker = RANKING()
    return {"ranking" : ranker.rank.remote(query['query'], query['documents'])}