File size: 585 Bytes
8fed2d7
 
 
 
 
 
ac50ec3
8fed2d7
 
 
 
 
 
f1ddd83
 
 
 
6bdf3ab
f1ddd83
a38553c
 
8fed2d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import asyncio
import threading
from typing import List

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel

from src.model.encoder import ProdFeatureEncoder
from src.config.config import ModelConfig

# ASGI application instance; route handlers in this module register on it via
# the `@app.get(...)` decorator.
app = FastAPI()

# Response schema for the `/encode_text/{text}` endpoint (used as its
# `response_model`): a single flat vector of floats.
class EmbeddingOutput(BaseModel):
    # Embedding vector; the endpoint fills it from the encoder output's `.tolist()`.
    embedding: List[float]


# Model configuration and encoder are built once at import time and shared by
# every request handled by this process.
config = ModelConfig()
model = ProdFeatureEncoder(config=config)
# Serving needs inference behaviour: a freshly constructed torch.nn.Module
# starts in training mode, where dropout is active and batch-norm uses batch
# statistics, so identical requests could yield different embeddings.
# `eval()` is idempotent, so this is harmless if the encoder already did it.
# The isinstance guard keeps import working in case ProdFeatureEncoder is not
# an nn.Module -- TODO confirm its base class and drop the guard.
if isinstance(model, torch.nn.Module):
    model.eval()

# Serializes calls into the model. Previously inference ran directly on the
# event loop, which implicitly allowed only one inference at a time; this lock
# keeps that guarantee now that inference runs on executor threads (the encoder
# and its tokenizer are not known to be safe for concurrent use -- TODO confirm
# before relaxing).
_inference_lock = threading.Lock()


def _encode_sync(text: str) -> List[float]:
    """Run the encoder on ``text`` and return the embedding as a Python list.

    Blocking helper intended to run on a worker thread.
    """
    # Grad mode is thread-local in PyTorch, so no_grad() must be entered on the
    # thread that actually executes the forward pass.
    with _inference_lock, torch.no_grad():
        embedding = model(text)
    # NOTE(review): assumes `model` returns a 1-D tensor; a batched (1, d)
    # output would serialize as a nested list and fail EmbeddingOutput
    # validation -- confirm the encoder's output shape.
    return embedding.tolist()


@app.get("/encode_text/{text}", response_model=EmbeddingOutput)
async def encode_text(text: str):
    """Encode ``text`` with the feature encoder and return its embedding.

    Model inference is blocking CPU/GPU work. Calling it directly inside this
    coroutine would freeze the event loop and stall every other request, so it
    is off-loaded to the loop's default thread-pool executor.

    NOTE(review): ``text`` is a single path segment, so inputs containing ``/``
    (even percent-encoded) never reach this handler and get a 404; consider a
    query or body parameter for arbitrary text.
    """
    loop = asyncio.get_running_loop()
    embedding = await loop.run_in_executor(None, _encode_sync, text)
    return {"embedding": embedding}