# WebashalarForML's picture
# Update app.py
# 3426b11 verified
import os
import json
import base64
import numpy as np
from flask import Flask, request, jsonify, render_template
from langchain_experimental.open_clip.open_clip import OpenCLIPEmbeddings
from sklearn.metrics.pairwise import cosine_similarity
from io import BytesIO
from PIL import Image
# from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from io import BytesIO
from pathlib import Path
# ============================== #
# INITIALIZE APP #
# ============================== #
# Flask application object and the OpenCLIP embeddings wrapper; both are
# created once, when this module is imported.
app = Flask(__name__)
clip_embd = OpenCLIPEmbeddings()

# Directory layout: everything hangs off BASE_DIR.
BASE_DIR = Path("/app")
BLOCKS_DIR = BASE_DIR.joinpath("blocks")
# STATIC_DIR = BASE_DIR / "static"
# GEN_PROJECT_DIR = BASE_DIR / "generated_projects"
BACKDROP_DIR = BLOCKS_DIR.joinpath("Backdrops")
SPRITE_DIR = BLOCKS_DIR.joinpath("sprites")
CODE_BLOCKS_DIR = BLOCKS_DIR.joinpath("code_blocks")
# === new: outputs rooted under BASE_DIR ===
OUTPUT_DIR = BASE_DIR.joinpath("outputs")
# ============================== #
# LOAD PRE-COMPUTED EMBEDS #
# ============================== #
# Read the pre-computed embedding index once, at import time. Every entry is
# indexed by its "file-path" and "embeddings" keys below.
# The path is built with pathlib (BLOCKS_DIR is a Path) and the file is
# decoded as UTF-8 explicitly so parsing does not depend on the host locale.
with open(BLOCKS_DIR / "embeddings.json", "r", encoding="utf-8") as f:
    embedding_json = json.load(f)

# Parallel sequences: image_paths[i] is the file whose embedding is
# image_embeds[i].
image_paths = [item["file-path"] for item in embedding_json]
image_embeds = np.array([item["embeddings"] for item in embedding_json])
# ============================== #
# HELPER: Decode Base64 Image #
# ============================== #
def decode_base64_image(b64_string):
    """Decode a base64-encoded image into an RGB PIL image.

    Args:
        b64_string: The base64 payload. A ``str`` may optionally carry a
            data-URL prefix (``data:<mime>;base64,``), which is removed
            before decoding.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        Whatever ``base64.b64decode`` or ``PIL.Image.open`` raise for
        malformed input.
    """
    # Plain base64 can never start with "data:" (':' is outside the base64
    # alphabet), so stripping the prefix is safe. Without this, b64decode
    # would silently drop the punctuation and decode the prefix letters as
    # payload, corrupting the image bytes.
    if isinstance(b64_string, str) and b64_string.startswith("data:"):
        b64_string = b64_string.split(",", 1)[-1]

    img_data = base64.b64decode(b64_string)
    img = Image.open(BytesIO(img_data)).convert("RGB")
    return img
# ============================== #
# API ROUTE #
# ============================== #
@app.route("/match", methods=["POST"])
def match_image():
    """Match each submitted image against the pre-computed embedding index.

    Expects a JSON object body with an "images" list of base64-encoded
    images; a ``data:<mime>;base64,`` prefix on an entry is tolerated.

    Returns:
        A JSON list with one entry per input image, in request order. A
        successful entry holds a truncated echo of the input under "input"
        and the closest stored image under "best_match" (file name, path and
        cosine similarity). An entry that failed holds ``{"error": message}``.
        Responds with HTTP 400 when the body is not a JSON object containing
        an "images" list.
    """
    data = request.get_json()
    # A body such as `null` or `42` parses as JSON but is not a mapping;
    # without this guard the membership test below would raise TypeError
    # and surface as an HTTP 500.
    if not isinstance(data, dict) or "images" not in data:
        return jsonify({"error": "No images provided"}), 400

    images = data["images"]
    if not isinstance(images, list):
        # Iterating a bare string would treat every character as an image.
        return jsonify({"error": "'images' must be a list"}), 400

    results = []
    for position, b64_img in enumerate(images):
        try:
            # Drop an optional data-URL prefix; plain base64 never starts
            # with "data:" because ':' is outside the base64 alphabet.
            payload = b64_img
            if isinstance(payload, str) and payload.startswith("data:"):
                payload = payload.split(",", 1)[-1]

            # Convert Base64 -> BytesIO
            b_io = BytesIO(base64.b64decode(payload))

            # Embed the query image
            query_embed = np.array(clip_embd.embed_image([b_io]))

            # Cosine similarity with stored embeddings; row 0 belongs to
            # the single query image.
            sims = cosine_similarity(query_embed, image_embeds)[0]
            best_idx = int(np.argmax(sims))
            best_path = image_paths[best_idx]

            results.append({
                "input": b64_img[:50] + "...",
                "best_match": {
                    "name": os.path.basename(best_path),
                    "path": best_path,
                    "similarity": float(sims[best_idx]),
                },
            })
        except Exception as e:
            # Best effort per image: one bad entry must not fail the whole
            # batch, but keep the traceback in the server log.
            app.logger.exception(
                "Failed to match image at position %d", position
            )
            results.append({"error": str(e)})

    return jsonify(results)
# ============================== #
# SIMPLE HTML UI #
# ============================== #
@app.route("/", methods=["GET", "POST"])
def index():
    """Render the ``index.html`` template for GET and POST requests to "/"."""
    page = render_template("index.html")
    return page
# ============================== #
# MAIN ENTRY #
# ============================== #
if __name__ == "__main__":
    # Start Flask's built-in server on port 7860 with debug mode enabled.
    # NOTE(review): debug mode is meant for development; confirm this entry
    # point is not what serves production traffic.
    app.run(port=7860, debug=True)