adammkhor commited on
Commit
f51716b
·
verified ·
1 Parent(s): e50ac0f

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +58 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import torch
4
+ from transformers import AutoModel, AutoTokenizer
5
+ from typing import List
6
+
7
+ # --- Configuration ---
8
+ EMBEDDING_MODEL_NAME = 'krutrim-ai-labs/vyakyarth'
9
+
10
+ # --- Helper Function for Mean Pooling ---
11
+ def mean_pooling(model_output, attention_mask):
12
+ token_embeddings = model_output[0]
13
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
14
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
15
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
16
+ return sum_embeddings / sum_mask
17
+
18
+ # --- FastAPI App Initialization ---
19
+ app = FastAPI(title="Embedding Service")
20
+
21
+ # --- Load Model on Startup ---
22
+ # This dictionary will hold the loaded model and tokenizer
23
+ model_payload = {}
24
+
25
+ @app.on_event("startup")
26
+ def load_model():
27
+ """Load the model and tokenizer when the server starts."""
28
+ print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}...")
29
+ model_payload['tokenizer'] = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
30
+ model_payload['model'] = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME)
31
+ model_payload['model'].eval()
32
+ print("Model loaded successfully.")
33
+
34
+ # --- Pydantic Models for Request/Response ---
35
+ class EmbeddingRequest(BaseModel):
36
+ text: str
37
+
38
+ class EmbeddingResponse(BaseModel):
39
+ embedding: List[float]
40
+
41
+ # --- API Endpoint ---
42
+ @app.post("/embed", response_model=EmbeddingResponse)
43
+ def create_embedding(request: EmbeddingRequest):
44
+ """Takes text and returns its vector embedding."""
45
+ tokenizer = model_payload['tokenizer']
46
+ model = model_payload['model']
47
+
48
+ encoded_input = tokenizer(request.text, padding=True, truncation=True, return_tensors='pt')
49
+ with torch.no_grad():
50
+ model_output = model(**encoded_input)
51
+
52
+ embedding = mean_pooling(model_output, encoded_input['attention_mask']).tolist()[0]
53
+
54
+ return {"embedding": embedding}
55
+
56
+ @app.get("/")
57
+ def read_root():
58
+ return {"message": "Embedding Service is running. Use the /embed endpoint."}
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ transformers
5
+ pydantic