fashxp committed
Commit 8f80642 · 1 Parent(s): afab4d9

added embedding endpoints

Files changed (5)
  1. Dockerfile +8 -2
  2. docker-compose.yaml +15 -2
  3. requirements.txt +1 -0
  4. src/embeddings.py +356 -0
  5. src/main.py +193 -1
Dockerfile CHANGED
@@ -4,14 +4,20 @@ RUN useradd -m -u 1000 user
 USER user
 
 ENV HOME=/home/user \
-    PATH=/home/user/.local/bin:$PATH
+    PATH=/home/user/.local/bin:$PATH \
+    PYTHONDONTWRITEBYTECODE=1 \
+    PYTHONUNBUFFERED=1
 
 WORKDIR $HOME/app
 
+# Copy requirements first for better caching
 COPY --chown=user requirements.txt requirements.txt
 
-RUN pip install --upgrade -r requirements.txt
+# Install dependencies with caching
+RUN pip install --upgrade pip && \
+    pip install --no-cache-dir --user -r requirements.txt
 
+# Copy application code
 COPY --chown=user . .
 
 CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "7860"]
docker-compose.yaml CHANGED
@@ -2,14 +2,27 @@ services:
   server:
     build:
       context: .
+      # Enable BuildKit for better caching
+      cache_from:
+        - python:3.9
     ports:
       - 7860:7860
     develop:
       watch:
+        # Only rebuild on requirements.txt changes, sync code changes otherwise
         - action: rebuild
-          path: .
+          path: ./requirements.txt
+        - action: sync
+          path: ./src
+          target: /home/user/app/src
+        - action: sync
+          path: ./README.md
+          target: /home/user/app/README.md
     volumes:
       - python-cache:/home/user/.cache
+      # Cache pip packages
+      - pip-cache:/home/user/.cache/pip
 
 volumes:
-  python-cache:
+  python-cache:
+  pip-cache:
requirements.txt CHANGED
@@ -6,4 +6,5 @@ sentencepiece
 sacremoses
 torch
 pillow
+protobuf
 # Optional dependencies for specific features
src/embeddings.py ADDED
@@ -0,0 +1,356 @@
+# -------------------------------------------------------------------
+# This source file is available under the terms of the
+# Pimcore Open Core License (POCL)
+# Full copyright and license information is available in
+# LICENSE.md which is distributed with this source code.
+#
+# @copyright Copyright (c) Pimcore GmbH (https://www.pimcore.com)
+# @license Pimcore Open Core License (POCL)
+# -------------------------------------------------------------------
+
+import torch
+import base64
+import io
+import logging
+from PIL import Image
+from pydantic import BaseModel
+from fastapi import Request, HTTPException
+import json
+from typing import Optional, Union, Dict, Any
+from transformers import AutoProcessor, AutoModel
+
+
+class EmbeddingRequest(BaseModel):
+    inputs: str
+    parameters: Optional[dict] = None
+
+
+class BaseEmbeddingTaskService:
+    """Base class for embedding services with common functionality"""
+
+    def __init__(self, logger: logging.Logger):
+        self._logger = logger
+        self._model_cache = {}
+        self._processor_cache = {}
+
+    async def get_embedding_request(self, request: Request) -> EmbeddingRequest:
+        """Parse request body into EmbeddingRequest"""
+        content_type = request.headers.get("content-type", "")
+        if content_type.startswith("application/json"):
+            data = await request.json()
+            return EmbeddingRequest(**data)
+        if content_type.startswith("application/x-www-form-urlencoded"):
+            raw = await request.body()
+            try:
+                data = json.loads(raw)
+                return EmbeddingRequest(**data)
+            except Exception:
+                try:
+                    data = json.loads(raw.decode("utf-8"))
+                    return EmbeddingRequest(**data)
+                except Exception:
+                    raise HTTPException(status_code=400, detail="Invalid request body")
+        raise HTTPException(status_code=400, detail="Unsupported content type")
+
+    def _get_device(self) -> torch.device:
+        """Get the appropriate device (GPU if available, otherwise CPU)"""
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self._logger.info(f"Using device: {device}")
+        return device
+
+    def _load_processor(self, model_name: str):
+        """Load and cache processor for the model using AutoProcessor"""
+        if model_name not in self._processor_cache:
+            try:
+                self._processor_cache[model_name] = AutoProcessor.from_pretrained(model_name)
+                self._logger.info(f"Loaded processor for model: {model_name}")
+            except Exception as e:
+                self._logger.error(f"Failed to load processor for model '{model_name}': {str(e)}")
+                raise HTTPException(
+                    status_code=404,
+                    detail=f"Processor for model '{model_name}' could not be loaded: {str(e)}"
+                )
+        return self._processor_cache[model_name]
+
+    def _load_model(self, model_name: str, cache_suffix: str = ""):
+        """Load and cache model using AutoModel"""
+        cache_key = f"{model_name}{cache_suffix}"
+        if cache_key not in self._model_cache:
+            try:
+                device = self._get_device()
+                model = AutoModel.from_pretrained(model_name)
+                model.to(device)
+                self._model_cache[cache_key] = model
+                self._logger.info(f"Loaded model: {model_name} on {device}")
+            except Exception as e:
+                self._logger.error(f"Failed to load model '{model_name}': {str(e)}")
+                raise HTTPException(
+                    status_code=404,
+                    detail=f"Model '{model_name}' could not be loaded: {str(e)}"
+                )
+        return self._model_cache[cache_key]
+
+    async def get_embedding_vector_size(self, model_name: str) -> dict:
+        """Get the vector size of embeddings for a given model"""
+        try:
+            # Load the model to get its configuration
+            model = self._load_model(model_name)
+
+            # Try to get the embedding dimension from the model configuration
+            used_attribute = None
+            if hasattr(model.config, 'hidden_size'):
+                vector_size = model.config.hidden_size
+                used_attribute = "hidden_size"
+            elif hasattr(model.config, 'projection_dim'):
+                vector_size = model.config.projection_dim
+                used_attribute = "projection_dim"
+            elif hasattr(model.config, 'd_model'):
+                vector_size = model.config.d_model
+                used_attribute = "d_model"
+            elif hasattr(model.config, 'text_config') and hasattr(model.config.text_config, 'hidden_size'):
+                vector_size = model.config.text_config.hidden_size
+                used_attribute = "text_config.hidden_size"
+            elif hasattr(model.config, 'vision_config') and hasattr(model.config.vision_config, 'hidden_size'):
+                vector_size = model.config.vision_config.hidden_size
+                used_attribute = "vision_config.hidden_size"
+            else:
+                # If we can't determine from config, we'll need to run a dummy inference
+                raise AttributeError("Could not determine vector size from model configuration")
+
+            self._logger.info(f"Model {model_name} has embedding vector size: {vector_size}")
+            return {
+                "model_name": model_name,
+                "vector_size": vector_size,
+                "config_attribute_used": used_attribute
+            }
+
+        except Exception as e:
+            self._logger.error(f"Failed to get vector size for model '{model_name}': {str(e)}")
+            raise HTTPException(
+                status_code=404,
+                detail=f"Could not determine vector size for model '{model_name}': {str(e)}"
+            )
+
+    def _extract_embeddings(self, model_output, model_name: str) -> torch.Tensor:
+        """Extract embeddings from model output with fallback strategies"""
+
+        # Try different embedding extraction methods in order of preference
+
+        # 1. Check for pooler_output (most common)
+        if hasattr(model_output, 'pooler_output') and model_output.pooler_output is not None:
+            self._logger.debug(f"Using pooler_output for {model_name}")
+            return model_output.pooler_output
+
+        # 2. Check for last_hidden_state and pool it
+        if hasattr(model_output, 'last_hidden_state') and model_output.last_hidden_state is not None:
+            self._logger.debug(f"Using pooled last_hidden_state for {model_name}")
+            # Mean pooling over sequence dimension
+            return model_output.last_hidden_state.mean(dim=1)
+
+        # 3. Check for image_embeds (CLIP-style models)
+        if hasattr(model_output, 'image_embeds') and model_output.image_embeds is not None:
+            self._logger.debug(f"Using image_embeds for {model_name}")
+            return model_output.image_embeds
+
+        # 4. Check for text_embeds (CLIP-style models)
+        if hasattr(model_output, 'text_embeds') and model_output.text_embeds is not None:
+            self._logger.debug(f"Using text_embeds for {model_name}")
+            return model_output.text_embeds
+
+        # 5. Fallback: try to use the output directly if it's a tensor
+        if isinstance(model_output, torch.Tensor):
+            self._logger.debug(f"Using direct tensor output for {model_name}")
+            return model_output
+
+        # 6. Last resort: check if output is a tuple and use the first element
+        if isinstance(model_output, tuple) and len(model_output) > 0:
+            self._logger.debug(f"Using first element of tuple output for {model_name}")
+            return model_output[0]
+
+        # If none of the above work, raise an error
+        raise HTTPException(
+            status_code=500,
+            detail=f"Could not extract embeddings from model output for {model_name}. "
+                   f"Available attributes: {dir(model_output) if hasattr(model_output, '__dict__') else 'Unknown'}"
+        )
+
+
+class ImageEmbeddingTaskService(BaseEmbeddingTaskService):
+    """Service for generating image embeddings"""
+
+    def _decode_base64_image(self, base64_string: str) -> Image.Image:
+        """Decode base64 string to PIL Image"""
+        try:
+            # Remove data URL prefix if present
+            if base64_string.startswith('data:image'):
+                base64_string = base64_string.split(',')[1]
+
+            image_data = base64.b64decode(base64_string)
+            image = Image.open(io.BytesIO(image_data))
+
+            # Convert to RGB if necessary
+            if image.mode != 'RGB':
+                image = image.convert('RGB')
+
+            return image
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}")
+
+    def _generate_image_embeddings(self, image: Image.Image, model, processor, model_name: str) -> list:
+        """Generate embeddings for an image"""
+        device = self._get_device()
+
+        # Process the image
+        inputs = processor(images=image, return_tensors="pt", padding=True)
+
+        # Move inputs to the same device as the model
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        # Get the embeddings
+        with torch.no_grad():
+            # Try using specialized methods first for CLIP-like models
+            if hasattr(model, 'get_image_features'):
+                self._logger.debug(f"Using get_image_features for {model_name}")
+                embeddings = model.get_image_features(pixel_values=inputs.get('pixel_values'))
+            elif hasattr(model, 'vision_model'):
+                self._logger.debug(f"Using vision_model for {model_name}")
+                vision_outputs = model.vision_model(**inputs)
+                embeddings = self._extract_embeddings(vision_outputs, model_name)
+            else:
+                self._logger.debug(f"Using full model for {model_name}")
+                outputs = model(**inputs)
+                embeddings = self._extract_embeddings(outputs, model_name)
+
+        self._logger.info(f"Image embedding shape: {embeddings.shape}")
+
+        # Move back to CPU before converting to numpy
+        embeddings_array = embeddings.cpu().numpy()
+
+        return embeddings_array[0].tolist()
+
+    async def generate_embedding(self, request: Request, model_name: str):
+        """Main method to generate image embeddings"""
+        embedding_request: EmbeddingRequest = await self.get_embedding_request(request)
+
+        self._logger.info(f"Generating image embedding for model: {model_name}")
+
+        # Load processor and model using auto-detection
+        processor = self._load_processor(model_name)
+        model = self._load_model(model_name, "_image")
+
+        # Decode image from base64
+        image = self._decode_base64_image(embedding_request.inputs)
+
+        try:
+            # Generate embeddings
+            embeddings = self._generate_image_embeddings(image, model, processor, model_name)
+
+            self._logger.info("Image embedding generation completed")
+            return {"embeddings": embeddings}
+
+        except Exception as e:
+            self._logger.error(f"Embedding generation failed for model '{model_name}': {str(e)}")
+            raise HTTPException(
+                status_code=500,
+                detail=f"Embedding generation failed: {str(e)}"
+            )
+
+    async def generate_embedding_from_upload(self, uploaded_file, model_name: str):
+        """Generate image embeddings from uploaded file"""
+        from fastapi import UploadFile
+
+        self._logger.info(f"Generating image embedding from uploaded file for model: {model_name}")
+
+        # Validate file type
+        if not uploaded_file.content_type.startswith('image/'):
+            raise HTTPException(
+                status_code=400,
+                detail=f"Invalid file type: {uploaded_file.content_type}. Only image files are supported."
+            )
+
+        try:
+            # Read file content
+            file_content = await uploaded_file.read()
+
+            # Convert to PIL Image
+            image = Image.open(io.BytesIO(file_content)).convert('RGB')
+
+            # Load processor and model using auto-detection
+            processor = self._load_processor(model_name)
+            model = self._load_model(model_name, "_image")
+
+            # Generate embeddings
+            embeddings = self._generate_image_embeddings(image, model, processor, model_name)
+
+            self._logger.info("Image embedding generation from upload completed")
+            return {"embeddings": embeddings}
+
+        except Exception as e:
+            self._logger.error(f"Embedding generation from upload failed for model '{model_name}': {str(e)}")
+            raise HTTPException(
+                status_code=500,
+                detail=f"Embedding generation from upload failed: {str(e)}"
+            )
+
+
+class TextEmbeddingTaskService(BaseEmbeddingTaskService):
+    """Service for generating text embeddings"""
+
+    def _generate_text_embeddings(self, text: str, model, processor, model_name: str) -> list:
+        """Generate embeddings for text"""
+        device = self._get_device()
+
+        # Process the text
+        inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
+
+        # Move inputs to the same device as the model
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        # Get the embeddings
+        with torch.no_grad():
+            # Try using specialized methods first for CLIP-like models
+            if hasattr(model, 'get_text_features'):
+                self._logger.debug(f"Using get_text_features for {model_name}")
+                embeddings = model.get_text_features(
+                    input_ids=inputs.get('input_ids'),
+                    attention_mask=inputs.get('attention_mask')
+                )
+            elif hasattr(model, 'text_model'):
+                self._logger.debug(f"Using text_model for {model_name}")
+                text_outputs = model.text_model(**inputs)
+                embeddings = self._extract_embeddings(text_outputs, model_name)
+            else:
+                self._logger.debug(f"Using full model for {model_name}")
+                outputs = model(**inputs)
+                embeddings = self._extract_embeddings(outputs, model_name)
+
+        self._logger.info(f"Text embedding shape: {embeddings.shape}")
+
+        # Move back to CPU before converting to numpy
+        embeddings_array = embeddings.cpu().numpy()
+
+        return embeddings_array[0].tolist()
+
+    async def generate_embedding(self, request: Request, model_name: str):
+        """Main method to generate text embeddings"""
+        embedding_request: EmbeddingRequest = await self.get_embedding_request(request)
+
+        self._logger.info(f"Generating text embedding for: {embedding_request.inputs[:50]}...")
+
+        # Load processor and model using auto-detection
+        processor = self._load_processor(model_name)
+        model = self._load_model(model_name, "_text")
+
+        try:
+            # Generate embeddings
+            embeddings = self._generate_text_embeddings(embedding_request.inputs, model, processor, model_name)
+
+            self._logger.info("Text embedding generation completed")
+            return {"embeddings": embeddings}
+
+        except Exception as e:
+            self._logger.error(f"Embedding generation failed for model '{model_name}': {str(e)}")
+            raise HTTPException(
+                status_code=500,
+                detail=f"Embedding generation failed: {str(e)}"
+            )
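
Note: the JSON endpoints wired up in src/main.py below send this file's EmbeddingRequest schema, i.e. a single "inputs" string. For image embeddings that string is the base64-encoded image, optionally as a data URL (the prefix is stripped by _decode_base64_image). A minimal client-side sketch of preparing that payload, using only the Python standard library; the file path "example.jpg" is a hypothetical placeholder, not part of this commit:

import base64
import json

# Hypothetical input file; replace with a real image path.
with open("example.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

# A plain base64 string is accepted ...
payload = {"inputs": encoded}

# ... and so is a data URL, since the server strips the "data:image/...;base64," prefix.
payload_data_url = {"inputs": "data:image/jpeg;base64," + encoded}

print(json.dumps(payload)[:80])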
src/main.py CHANGED
@@ -10,13 +10,14 @@
 
 import torch
 
-from fastapi import FastAPI, Path, Request
+from fastapi import FastAPI, Path, Request, File, UploadFile
 import logging
 import sys
 
 from .translation_task import TranslationTaskService
 from .classification import ClassificationTaskService
 from .text_to_image import TextToImageTaskService
+from .embeddings import ImageEmbeddingTaskService, TextEmbeddingTaskService
 
 app = FastAPI(
     title="Pimcore Local Inference Service",
@@ -294,3 +295,194 @@ async def image_to_text(
     model_name = model_name.rstrip("/")
     imageToTextTask = TextToImageTaskService(logger)
     return await imageToTextTask.extract(request, model_name)
+
+
+# =========================
+# Image Embedding Task
+# =========================
+@app.post(
+    "/image-embedding/{model_name:path}",
+    openapi_extra={
+        "requestBody": {
+            "content": {
+                "application/json": {
+                    "example": {
+                        "inputs": "base64_encoded_image_string"
+                    }
+                }
+            }
+        }
+    }
+)
+async def image_embedding(
+    request: Request,
+    model_name: str = Path(
+        ...,
+        description="The name of the image embedding model. Supported models include: google/siglip-so400m-patch14-384, openai/clip-vit-large-patch14, openai/clip-vit-base-patch16, laion/CLIP-ViT-bigG-14-laion2B-39B-b160k, Salesforce/blip-itm-large-flickr",
+        example="google/siglip-so400m-patch14-384"
+    )
+):
+    """
+    Generate embedding vectors for image data.
+
+    The service supports multiple model types including SigLIP, CLIP, and BLIP models.
+    Returns a dense vector representation of the input image.
+
+    Returns:
+        list: The embedding vector as a list of float values.
+    """
+
+    model_name = model_name.rstrip("/")
+    imageEmbeddingTask = ImageEmbeddingTaskService(logger)
+    return await imageEmbeddingTask.generate_embedding(request, model_name)
+
+
+# =========================
+# Image Embedding Upload Task (Development/Testing)
+# =========================
+@app.post(
+    "/image-embedding-upload/{model_name:path}",
+    openapi_extra={
+        "requestBody": {
+            "content": {
+                "multipart/form-data": {
+                    "schema": {
+                        "type": "object",
+                        "properties": {
+                            "image": {
+                                "type": "string",
+                                "format": "binary",
+                                "description": "Image file to upload for embedding generation"
+                            }
+                        },
+                        "required": ["image"]
+                    }
+                }
+            }
+        },
+        "responses": {
+            "200": {
+                "description": "Image embedding vector",
+                "content": {
+                    "application/json": {
+                        "example": {
+                            "embeddings": [0.1, -0.2, 0.3, "..."]
+                        }
+                    }
+                }
+            }
+        }
+    }
+)
+async def image_embedding_upload(
+    image: UploadFile = File(..., description="Image file to generate embeddings for"),
+    model_name: str = Path(
+        ...,
+        description="The name of the image embedding model. Supported models include: google/siglip-so400m-patch14-384, openai/clip-vit-large-patch14, openai/clip-vit-base-patch16, laion/CLIP-ViT-bigG-14-laion2B-39B-b160k, Salesforce/blip-itm-large-flickr",
+        example="google/siglip-so400m-patch14-384"
+    )
+):
+    """
+    Generate embedding vectors for uploaded image data (Development/Testing endpoint).
+
+    This endpoint allows you to upload an image file directly through the Swagger UI
+    for development and testing purposes. The image is processed and converted to
+    embedding vectors using the specified model.
+
+    Supported formats: JPEG, PNG, GIF, BMP, TIFF
+
+    The service supports multiple model types including SigLIP, CLIP, and BLIP models.
+    Returns a dense vector representation of the uploaded image.
+
+    Returns:
+        dict: The embedding vector as a list of float values.
+    """
+
+    model_name = model_name.rstrip("/")
+    imageEmbeddingTask = ImageEmbeddingTaskService(logger)
+    return await imageEmbeddingTask.generate_embedding_from_upload(image, model_name)
+
+
+# =========================
+# Text Embedding Task
+# =========================
+@app.post(
+    "/text-embedding/{model_name:path}",
+    openapi_extra={
+        "requestBody": {
+            "content": {
+                "application/json": {
+                    "example": {
+                        "inputs": "text to embed"
+                    }
+                }
+            }
+        }
+    }
+)
+async def text_embedding(
+    request: Request,
+    model_name: str = Path(
+        ...,
+        description="The name of the text embedding model. Supported models include: google/siglip-so400m-patch14-384, openai/clip-vit-large-patch14, openai/clip-vit-base-patch16, laion/CLIP-ViT-bigG-14-laion2B-39B-b160k, Salesforce/blip-itm-large-flickr",
+        example="google/siglip-so400m-patch14-384"
+    )
+):
+    """
+    Generate embedding vectors for text data.
+
+    The service supports multiple model types including SigLIP, CLIP, and BLIP models.
+    Returns a dense vector representation of the input text.
+
+    Returns:
+        list: The embedding vector as a list of float values.
+    """
+
+    model_name = model_name.rstrip("/")
+    textEmbeddingTask = TextEmbeddingTaskService(logger)
+    return await textEmbeddingTask.generate_embedding(request, model_name)
+
+
+# =========================
+# Embedding Vector Size
+# =========================
+@app.get(
+    "/embedding-vector-size/{model_name:path}",
+    openapi_extra={
+        "responses": {
+            "200": {
+                "description": "Vector size information",
+                "content": {
+                    "application/json": {
+                        "example": {
+                            "model_name": "google/siglip-so400m-patch14-384",
+                            "vector_size": 1152,
+                            "config_attribute_used": "hidden_size"
+                        }
+                    }
+                }
+            }
+        }
+    }
+)
+async def embedding_vector_size(
+    model_name: str = Path(
+        ...,
+        description="The name of the embedding model. Supported models include: google/siglip-so400m-patch14-384, openai/clip-vit-large-patch14, openai/clip-vit-base-patch16, laion/CLIP-ViT-bigG-14-laion2B-39B-b160k, Salesforce/blip-itm-large-flickr",
+        example="google/siglip-so400m-patch14-384"
+    )
+):
+    """
+    Get the vector size of embeddings for a given model.
+
+    This endpoint returns the dimensionality of the embedding vectors that the model produces.
+    Useful for understanding the output format before generating embeddings.
+
+    Returns:
+        dict: Information about the vector size including model name, vector size, and configuration attribute used.
+    """
+
+    model_name = model_name.rstrip("/")
+    # We can use either ImageEmbeddingTaskService or TextEmbeddingTaskService as they inherit from the same base class
+    embeddingTask = ImageEmbeddingTaskService(logger)
+    return await embeddingTask.get_embedding_vector_size(model_name)
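
Usage note for the endpoints added above: a minimal client sketch, assuming the service is running locally on port 7860 (the port used in the Dockerfile CMD) and that the requests package is installed; the chosen model and input text are illustrative examples only.

import requests

BASE_URL = "http://localhost:7860"  # assumption: local instance started via docker compose / uvicorn
MODEL = "openai/clip-vit-base-patch16"  # example model from the endpoint descriptions

# Query the embedding dimensionality reported by /embedding-vector-size.
size_info = requests.get(f"{BASE_URL}/embedding-vector-size/{MODEL}").json()
print(size_info)

# Generate a text embedding; the body follows the EmbeddingRequest schema ({"inputs": ...}).
response = requests.post(
    f"{BASE_URL}/text-embedding/{MODEL}",
    json={"inputs": "a photo of a cat"},
)
vector = response.json()["embeddings"]
print(len(vector))  # dimensionality of the returned embedding vector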