|
from modal import Image, App, Secret, web_endpoint, Volume, enter, method, build |
|
from typing import Dict |
|
import sys |
|
|
|
# Container image: slim Debian with Python 3.12 plus the vector-search stack.
# pysqlite3-binary provides a standalone SQLite build that is swapped in for
# the stdlib sqlite3 module at container start (see the imports block below).
model_image = Image.debian_slim(python_version="3.12").pip_install(
    "chromadb",
    "sentence-transformers",
    "pysqlite3-binary",
)
|
|
|
|
|
# These imports run inside the Modal container at runtime, not locally.
with model_image.imports():

    import os

    import numpy as np

    # Swap the stdlib sqlite3 module for pysqlite3 (a newer SQLite build)
    # BEFORE chromadb is imported anywhere in this process — presumably
    # chromadb needs a newer SQLite than the image ships; the swap must
    # precede any `import chromadb` (done later in VECTORDB.init).
    __import__("pysqlite3")

    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
|
|
|
|
|
# Modal application handle; every function/class below is attached to it
# and runs on the image built above.
app = App("mps-api", image=model_image)

# Pre-existing persistent volume holding the chroma database.
vol = Volume.from_name("mps", create_if_missing=False)

# Mount point for that volume inside containers.
data_path = "/data"
|
|
|
|
|
|
|
|
|
@app.cls(timeout=30*60,
         volumes={data_path: vol})
class VECTORDB:
    """Vector store backed by a persistent chromadb collection on the volume."""

    @enter()
    @build()
    def init(self):
        """Load the embedding model and open the "MPS" collection."""
        # Embedding function shared by build-time snapshot and runtime queries.
        from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

        model_name = "Lajavaness/sentence-camembert-large"
        self.embedding_function = SentenceTransformerEmbeddingFunction(model_name=model_name)
        print(f"Embedding model loaded: {model_name}")

        # Open the persisted database from the mounted volume (/data/db).
        import chromadb

        db_dir = data_path + "/db"
        collection_name = "MPS"
        client = chromadb.PersistentClient(path=db_dir)
        self.chroma_collection = client.get_collection(name=collection_name, embedding_function=self.embedding_function)
        print(f"{self.chroma_collection.count()} documents loaded.")

    @method()
    def search(self, queries, origins, n_results=10):
        """Nearest-neighbour search restricted to the given "origin" values.

        Returns a (documents, metadatas, distances) triple as produced by
        chromadb's query call.
        """
        hits = self.chroma_collection.query(
            query_texts=queries,
            n_results=n_results,
            where={"origin": {"$in": origins}},
            include=['documents', 'metadatas', 'distances'],
        )
        return hits['documents'], hits['metadatas'], hits['distances']
|
|
|
@app.cls(timeout=30*60)
class RANKING:
    """Cross-encoder re-ranking service."""

    @enter()
    @build()
    def init(self):
        """Load the cross-encoder once per container (and at build time)."""
        from sentence_transformers import CrossEncoder

        model_name = "Lajavaness/CrossEncoder-camembert-large"
        self.cross_encoder = CrossEncoder(model_name)
        print(f"Cross encoder model loaded: {model_name}")

    @method()
    def rank(self, query, documents):
        """Return document indices ordered from most to least relevant."""
        candidate_pairs = [[query, candidate] for candidate in documents]
        relevance = self.cross_encoder.predict(candidate_pairs)
        # Highest score first; `np` is provided by the image import context.
        return np.argsort(relevance)[::-1].tolist()
|
|
|
|
|
|
|
|
|
@app.function(timeout=30*60)
@web_endpoint(method="POST")
def retrieve(query: Dict):
    """POST endpoint: semantic search over the MPS collection.

    Expected JSON body:
        query     -- query text (or list of query texts)
        origins   -- list of allowed "origin" metadata values
        n_results -- optional number of hits per query (default 10)

    Returns {"documents": ..., "metadatas": ..., "distances": ...}.
    """
    print(f"Retrieve query: {query}...")

    # Fix: default n_results instead of raising KeyError when the caller
    # omits it, mirroring VECTORDB.search's own default of 10.
    documents, metadatas, distances = VECTORDB().search.remote(
        query['query'], query['origins'], query.get('n_results', 10))

    return {"documents" : documents, "metadatas" : metadatas, "distances" : distances}
|
|
|
@app.function(timeout=30*60)
@web_endpoint(method="POST")
def rank(query: Dict):
    """POST endpoint: re-rank candidate documents against a query.

    Expected JSON body: query (text) and documents (list of texts).
    Returns {"ranking": [...]} with indices ordered best-first.
    """
    print(f"Rank query: {query}...")

    order = RANKING().rank.remote(query['query'], query['documents'])

    return {"ranking" : order}
|
|