import os
import json
import base64
import numpy as np
from flask import Flask, request, jsonify, render_template
from langchain_experimental.open_clip.open_clip import OpenCLIPEmbeddings
from sklearn.metrics.pairwise import cosine_similarity
from io import BytesIO
from PIL import Image
# from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from pathlib import Path

# ============================== #
#         INITIALIZE APP         #
# ============================== #
app = Flask(__name__)
# CLIP embedder used to embed incoming query images; built once at import time.
clip_embd = OpenCLIPEmbeddings()

BASE_DIR = Path("/app")
BLOCKS_DIR = BASE_DIR / "blocks"
# STATIC_DIR = BASE_DIR / "static"
# GEN_PROJECT_DIR = BASE_DIR / "generated_projects"
BACKDROP_DIR = BLOCKS_DIR / "Backdrops"
SPRITE_DIR = BLOCKS_DIR / "sprites"
CODE_BLOCKS_DIR = BLOCKS_DIR / "code_blocks"
# === new: outputs rooted under BASE_DIR ===
OUTPUT_DIR = BASE_DIR / "outputs"

# ============================== #
#     LOAD PRE-COMPUTED EMBEDS   #
# ============================== #
# embeddings.json is a list of records, each with a "file-path" and an
# "embeddings" vector. The two module-level arrays below are index-aligned:
# row i of image_embeds is the embedding of image_paths[i].
with open(BLOCKS_DIR / "embeddings.json", "r", encoding="utf-8") as f:
    embedding_json = json.load(f)

image_paths = [item["file-path"] for item in embedding_json]
image_embeds = np.array([item["embeddings"] for item in embedding_json])


# ============================== #
#   HELPER: Decode Base64 Image  #
# ============================== #
def _b64_to_bytes(b64_string):
    """Decode a base64 payload, tolerating an optional ``data:`` URL prefix.

    Browser APIs commonly produce strings such as
    ``data:image/png;base64,iVBOR...``. ``base64.b64decode`` (non-validating
    mode) discards the ``:``, ``;`` and ``,`` of such a prefix but keeps its
    letters, which corrupts the decoded bytes. So the prefix is stripped first.

    Args:
        b64_string: Base64 text (``str``) or bytes, optionally a data URL.

    Returns:
        The decoded raw bytes.

    Raises:
        binascii.Error: If the payload is not valid base64 (incorrect padding).
    """
    if isinstance(b64_string, str) and b64_string.startswith("data:"):
        _, _, b64_string = b64_string.partition(",")
    return base64.b64decode(b64_string)


def decode_base64_image(b64_string):
    """Decode a base64 (or data-URL) encoded image into an RGB PIL image."""
    img_data = _b64_to_bytes(b64_string)
    return Image.open(BytesIO(img_data)).convert("RGB")


# ============================== #
#            API ROUTE           #
# ============================== #
@app.route("/match", methods=["POST"])
def match_image():
    """Find the most similar pre-computed image for each posted image.

    Expects a JSON object body ``{"images": [<base64 string>, ...]}``; a single
    base64 string instead of a list is also accepted. Each entry may be plain
    base64 or a ``data:`` URL.

    Returns:
        A JSON list with one entry per input image, in input order. Each entry
        is either ``{"input": <first 50 chars>..., "best_match": {"name",
        "path", "similarity"}}`` or ``{"error": <message>}`` when that image
        could not be decoded or embedded. Responds with HTTP 400 and a JSON
        error when the body is not a JSON object, lacks ``images``, or
        ``images`` is neither a string nor a list.
    """
    # silent=True returns None instead of raising on a missing or malformed
    # JSON body, so the client always receives a JSON error (not an HTML page
    # or a 500 from indexing into None).
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "images" not in data:
        return jsonify({"error": "No images provided"}), 400

    images = data["images"]
    if isinstance(images, str):
        # A bare string would otherwise be iterated character by character.
        images = [images]
    if not isinstance(images, list):
        return jsonify({"error": "'images' must be a list of base64 strings"}), 400

    results = []
    for idx, b64_img in enumerate(images):
        try:
            # Convert Base64 -> file-like BytesIO for the embedder.
            b_io = BytesIO(_b64_to_bytes(b64_img))

            # Embed the query image; result is a one-row 2D array.
            query_embed = np.array(clip_embd.embed_image([b_io]))

            # Cosine similarity against every stored embedding; pick the best.
            sims = cosine_similarity(query_embed, image_embeds)[0]
            best_idx = int(np.argmax(sims))

            results.append({
                "input": b64_img[:50] + "...",
                "best_match": {
                    "name": os.path.basename(image_paths[best_idx]),
                    "path": image_paths[best_idx],
                    "similarity": float(sims[best_idx]),
                },
            })
        except Exception as e:
            # Per-image boundary: one bad image must not fail the whole batch,
            # but the failure is logged rather than silently swallowed.
            app.logger.exception("Failed to match image at index %d", idx)
            results.append({"error": str(e)})

    return jsonify(results)


# ============================== #
#         SIMPLE HTML UI         #
# ============================== #
@app.route("/", methods=["GET", "POST"])
def index():
    """Render ``index.html`` as the simple HTML UI."""
    return render_template("index.html")


# ============================== #
#           MAIN ENTRY           #
# ============================== #
if __name__ == "__main__":
    # NOTE: debug mode enables the Werkzeug interactive debugger, which allows
    # arbitrary code execution from the browser; never expose it publicly.
    # Defaults to on (as before); set FLASK_DEBUG=0 to turn it off.
    app.run(debug=os.environ.get("FLASK_DEBUG", "1") == "1", port=7860)