|
from fastapi import FastAPI, Response, HTTPException |
|
from datasets import load_dataset |
|
from io import BytesIO |
|
from PIL import Image |
|
|
|
# FastAPI application; the routes below are registered on it via decorators.
app = FastAPI()

# Hugging Face dataset (and split) that backs the image endpoints.
DATASET_NAME = "visionLMsftw/vibe-testing-samples"
DATASET_SPLIT = "train"

# Runs once at import time, i.e. when the server process starts.
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)

# Map each row's ``ex_id`` to its ``image`` value.
# Presumably the image column is decoded to ``PIL.Image.Image`` objects when
# iterated (``get_image`` calls ``.convert`` on them) — TODO confirm feature type.
# NOTE(review): ``get_image`` looks ids up as ``int``; this assumes ``ex_id`` is
# an integer column. If it is a string column, every lookup would 404.
id_to_image = {example["ex_id"]: example["image"] for example in dataset}

# A dict comprehension keeps only the last value for a repeated key, so report
# any rows that were silently overwritten instead of hiding them.
_duplicate_count = len(dataset) - len(id_to_image)
if _duplicate_count:
    print(
        f"Warning: {_duplicate_count} rows share an ex_id with an earlier row "
        "and were overwritten"
    )

# Print a short summary rather than the mapping itself, whose repr would
# include every image object.
print(f"Loaded {len(id_to_image)} images from {DATASET_NAME} ({DATASET_SPLIT} split)")
|
|
|
@app.get("/image/{image_id}")
def get_image(image_id: int):
    """Serve the image stored under ``image_id`` as a JPEG.

    Args:
        image_id: Key into ``id_to_image``, parsed from the URL path as an int.

    Returns:
        A ``Response`` whose body is the image re-encoded as JPEG (quality 85)
        with media type ``image/jpeg``.

    Raises:
        HTTPException: 404 if no image is registered under ``image_id``.
    """
    # Single lookup (EAFP) instead of a membership test followed by indexing.
    try:
        source = id_to_image[image_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Image not found") from None

    # JPEG has no alpha channel or palette mode; Pillow cannot write e.g. RGBA
    # or P-mode images as JPEG, so normalize to RGB before encoding.
    image: Image.Image = source.convert("RGB")

    buffer = BytesIO()
    image.save(buffer, format="JPEG", quality=85)
    # getvalue() returns the full buffer contents regardless of the stream
    # position, so no seek(0)/read() round-trip is needed.
    return Response(content=buffer.getvalue(), media_type="image/jpeg")
|
|
|
@app.get("/ids")
def list_ids():
    """Return every id that ``/image/{image_id}`` can serve.

    Ids are listed in the order each one first appeared in the loaded dataset
    split (dict insertion order of ``id_to_image``).
    """
    # Iterating a dict yields its keys directly; ``.keys()`` is redundant.
    return list(id_to_image)
|
|
|
|