File size: 1,169 Bytes
ede4edf
 
 
 
a6e5f2e
ede4edf
 
eebd998
ede4edf
a6e5f2e
 
ede4edf
 
 
 
 
a6e5f2e
20c7ee4
ede4edf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List
from transformers import pipeline

# Build the zero-shot classification pipeline once, at import time, so the
# /classify handler reuses a single instance. No model name is passed here,
# so transformers falls back to its default model for this task.
classifier = pipeline("zero-shot-classification")

# FastAPI application object; the route decorators below register on it.
app = FastAPI()

# Request body schema for POST /classify; both fields are required.
class ClassificationRequest(BaseModel):
    # Text to be classified.
    text: str = Field(
        ...,
        example="This is a course about the Transformers library",
    )
    # Candidate labels handed to the classifier alongside the text.
    labels: List[str] = Field(
        ...,
        example=["education", "politics", "technology"],
    )

@app.get("/")
def greet_json():
    """Return a fixed greeting payload for GET requests to the root path."""
    greeting = {"Hello": "World!"}
    return greeting

@app.post("/classify")
def zero_shot_classification(request: ClassificationRequest):
    """
    Run zero-shot classification on the request text using the candidate labels.

    Args:
        request: Validated request body carrying ``text`` and ``labels``.

    Returns:
        The value returned by the module-level ``classifier`` pipeline for
        the given text and candidate labels, passed through unchanged.

    Raises:
        HTTPException: 422 when ``labels`` is empty; 500 (with the underlying
            error message as ``detail``) when the pipeline call itself fails.
    """
    # An empty label list is a client mistake, not a server fault. Reject it
    # before touching the model, and do so outside the ``try`` below:
    # HTTPException is an Exception subclass, so raising it inside the
    # ``try`` would be caught and rewritten as a 500.
    if not request.labels:
        raise HTTPException(
            status_code=422,
            detail="labels must contain at least one candidate label",
        )

    try:
        # Keep the guarded region to the one call that can fail.
        result = classifier(request.text, candidate_labels=request.labels)
    except Exception as e:
        # Top-level boundary: surface any pipeline failure as a 500 and keep
        # the original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return result