adammkhor committed
Commit 89bf46f · verified · 1 Parent(s): f51716b

Update app.py

Files changed (1):
  app.py +59 -57
app.py CHANGED
@@ -1,58 +1,60 @@
-from fastapi import FastAPI
-from pydantic import BaseModel
-import torch
-from transformers import AutoModel, AutoTokenizer
-from typing import List
-
-# --- Configuration ---
-EMBEDDING_MODEL_NAME = 'krutrim-ai-labs/vyakyarth'
-
-# --- Helper Function for Mean Pooling ---
-def mean_pooling(model_output, attention_mask):
-    token_embeddings = model_output[0]
-    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
-    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-    return sum_embeddings / sum_mask
-
-# --- FastAPI App Initialization ---
-app = FastAPI(title="Embedding Service")
-
-# --- Load Model on Startup ---
-# This dictionary will hold the loaded model and tokenizer
-model_payload = {}
-
-@app.on_event("startup")
-def load_model():
-    """Load the model and tokenizer when the server starts."""
-    print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}...")
-    model_payload['tokenizer'] = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
-    model_payload['model'] = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME)
-    model_payload['model'].eval()
-    print("Model loaded successfully.")
-
-# --- Pydantic Models for Request/Response ---
-class EmbeddingRequest(BaseModel):
-    text: str
-
-class EmbeddingResponse(BaseModel):
-    embedding: List[float]
-
-# --- API Endpoint ---
-@app.post("/embed", response_model=EmbeddingResponse)
-def create_embedding(request: EmbeddingRequest):
-    """Takes text and returns its vector embedding."""
-    tokenizer = model_payload['tokenizer']
-    model = model_payload['model']
-
-    encoded_input = tokenizer(request.text, padding=True, truncation=True, return_tensors='pt')
-    with torch.no_grad():
-        model_output = model(**encoded_input)
-
-    embedding = mean_pooling(model_output, encoded_input['attention_mask']).tolist()[0]
-
-    return {"embedding": embedding}
-
-@app.get("/")
-def read_root():
+import os
+os.environ['TRANSFORMERS_CACHE'] = '/tmp/cache'
+from fastapi import FastAPI
+from pydantic import BaseModel
+import torch
+from transformers import AutoModel, AutoTokenizer
+from typing import List
+
+# --- Configuration ---
+EMBEDDING_MODEL_NAME = 'krutrim-ai-labs/vyakyarth'
+
+# --- Helper Function for Mean Pooling ---
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+    return sum_embeddings / sum_mask
+
+# --- FastAPI App Initialization ---
+app = FastAPI(title="Embedding Service")
+
+# --- Load Model on Startup ---
+# This dictionary will hold the loaded model and tokenizer
+model_payload = {}
+
+@app.on_event("startup")
+def load_model():
+    """Load the model and tokenizer when the server starts."""
+    print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}...")
+    model_payload['tokenizer'] = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
+    model_payload['model'] = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME)
+    model_payload['model'].eval()
+    print("Model loaded successfully.")
+
+# --- Pydantic Models for Request/Response ---
+class EmbeddingRequest(BaseModel):
+    text: str
+
+class EmbeddingResponse(BaseModel):
+    embedding: List[float]
+
+# --- API Endpoint ---
+@app.post("/embed", response_model=EmbeddingResponse)
+def create_embedding(request: EmbeddingRequest):
+    """Takes text and returns its vector embedding."""
+    tokenizer = model_payload['tokenizer']
+    model = model_payload['model']
+
+    encoded_input = tokenizer(request.text, padding=True, truncation=True, return_tensors='pt')
+    with torch.no_grad():
+        model_output = model(**encoded_input)
+
+    embedding = mean_pooling(model_output, encoded_input['attention_mask']).tolist()[0]
+
+    return {"embedding": embedding}
+
+@app.get("/")
+def read_root():
     return {"message": "Embedding Service is running. Use the /embed endpoint."}