Spaces:
Running
on
Zero
Running
on
Zero
| import argparse | |
| import base64 | |
| import os | |
| import pickle | |
| import time | |
| from typing import Dict, List | |
| import cv2 | |
| import numpy as np | |
| import requests | |
# Base URL of the image matching service. The REMOTE_URL_RAILWAY environment
# variable, when set, overrides the local default.
ENDPOINT = os.environ.get("REMOTE_URL_RAILWAY", "http://127.0.0.1:8000")
print(f"API ENDPOINT: {ENDPOINT}")

# Fully qualified URLs of the individual API routes.
API_VERSION = f"{ENDPOINT}/version"
API_URL_MATCH = f"{ENDPOINT}/v1/match"
API_URL_EXTRACT = f"{ENDPOINT}/v1/extract"
def read_image(path: str) -> str:
    """
    Read an image from disk as grayscale and return it as a base64 PNG string.

    The image is encoded as PNG (lossless) rather than JPEG so that no
    compression artifacts are introduced before it is sent to the API.

    Args:
        path (str): The path to the image to read.

    Returns:
        str: The base64 encoded PNG bytes, decoded as a UTF-8 string.

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from ``path``.
        ValueError: If OpenCV fails to encode the image as PNG.
    """
    # cv2.imread signals failure by returning None instead of raising, so
    # check explicitly to fail with a clear message.
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")

    # Encode the image as PNG; the first return value is the success flag.
    encoded_ok, buffer = cv2.imencode(".png", img)
    if not encoded_ok:
        raise ValueError(f"Could not encode image as PNG: {path}")

    # Encode the PNG bytes as a base64 string
    return base64.b64encode(buffer).decode("utf-8")
def do_api_requests(url=API_URL_EXTRACT, **kwargs):
    """
    POST a JSON request to the image matching service and return its reply.

    A default request body is built and then overridden, key by key, with
    whatever is passed in ``kwargs``.

    Args:
        url (str): The URL of the API endpoint to use. Defaults to the
            feature extraction endpoint.
        **kwargs: Entries that are added to, or replace, the default request
            body (for example ``data=[...]``).

    Returns:
        The JSON-decoded response body when the server answers with status
        200. ``None`` when the status is anything else or when any exception
        occurs; in both failure cases a message is printed instead of
        raising.
        NOTE(review): for the extract endpoint this is presumably a list of
        per-image dicts ("keypoints", "descriptors", "scores") - confirm
        against the server.
    """
    default_body = {
        # List of image data base64 encoded
        "data": [],
        # List of maximum number of keypoints to extract from each image
        "max_keypoints": [100, 100],
        # List of timestamps for each image (not used?)
        "timestamps": ["0", "1"],
        # Whether to convert the images to grayscale
        "grayscale": 0,
        # List of image height and width
        "image_hw": [[640, 480], [320, 240]],
        # Type of feature to extract
        "feature_type": 0,
        # List of rotation angles for each image
        "rotates": [0.0, 0.0],
        # List of scale factors for each image
        "scales": [1.0, 1.0],
        # List of reference points for each image (not used)
        "reference_points": [[640, 480], [320, 240]],
        # Whether to binarize the descriptors
        "binarize": True,
    }
    # Caller-supplied entries take precedence over the defaults
    payload = {**default_body, **kwargs}

    try:
        reply = requests.post(url, json=payload)
        if reply.status_code != 200:
            # Report the failure; the caller receives None
            print(f"Error: Response code {reply.status_code} - {reply.text}")
            return None
        return reply.json()
    except Exception as e:
        # Best effort: report the problem instead of propagating it
        print(f"An error occurred: {e}")
    return None
def send_request_match(path0: str, path1: str) -> Dict[str, np.ndarray]:
    """
    Send a request to the API to generate a match between two images.

    Both files are uploaded as multipart form fields named "image0" and
    "image1". The files are opened with a context manager so that both
    handles are closed even if opening the second file or the request
    itself fails.

    Args:
        path0 (str): The path to the first image.
        path1 (str): The path to the second image.

    Returns:
        Dict[str, np.ndarray]: The JSON response with every value converted
        to an ndarray. An empty dict is returned (and an error message is
        printed) when the server does not answer with status 200.

    Raises:
        OSError: If either image file cannot be opened.
        Exception: Errors raised by ``requests.post`` or by JSON decoding
            are not caught here and propagate to the caller.
    """
    pred = {}
    with open(path0, "rb") as image0, open(path1, "rb") as image1:
        # TODO: replace files with post json
        response = requests.post(
            API_URL_MATCH, files={"image0": image0, "image1": image1}
        )

    if response.status_code == 200:
        pred = {
            key: np.array(value) for key, value in response.json().items()
        }
    else:
        print(
            f"Error: Response code {response.status_code} - {response.text}"
        )
    return pred
def send_request_extract(
    input_images: str, viz: bool = False
) -> List[Dict[str, np.ndarray]]:
    """
    Send a request to the API to extract features from an image.

    Args:
        input_images (str): The path to the image.
        viz (bool): Debug only. When True and the response contains an
            "image_orig" entry, the keypoints are drawn on that image and
            the figure is written to "demo_match.jpg". Requires the
            project-local ``hloc`` and ``ui`` packages.

    Returns:
        The value returned by ``do_api_requests`` for the extract endpoint,
        passed through unchanged. This is ``None`` when the request failed.
        NOTE(review): on success this is presumably a list of per-image
        dicts with "keypoints", "descriptors" and "scores" - confirm
        against the server.
    """
    image_data = read_image(input_images)
    response = do_api_requests(
        url=API_URL_EXTRACT,
        data=[image_data],
    )

    # draw keypoints, debug only
    # Skip drawing when the request failed (do_api_requests has already
    # printed the error) instead of crashing on response[0].
    if viz and response:
        from hloc.utils.viz import plot_keypoints
        from ui.viz import fig2im, plot_images

        first_result = response[0]
        kpts = np.array(first_result["keypoints_orig"])
        if "image_orig" in first_result:
            # Use the image returned by the server, not the key name itself.
            img_orig = np.array(first_result["image_orig"])
            output_keypoints = plot_images(
                [img_orig], titles="titles", dpi=300
            )
            plot_keypoints([kpts])
            output_keypoints = fig2im(output_keypoints)
            cv2.imwrite(
                "demo_match.jpg",
                output_keypoints[:, :, ::-1].copy(),  # RGB -> BGR
            )
    return response
def get_api_version():
    """
    Query the version endpoint and print the reported API version.

    Any failure (connection problem, invalid JSON, missing "version" key)
    is reported with a printed message instead of being raised. Nothing is
    returned.
    """
    try:
        reply = requests.get(API_VERSION)
        payload = reply.json()
        version = payload["version"]
        print("API VERSION: {}".format(version))
    except Exception as e:
        print(f"An error occurred: {e}")
if __name__ == "__main__":
    # Command-line entry point: print the API version, then repeatedly time
    # the feature extraction endpoint on --image0 and pickle the last result.
    parser = argparse.ArgumentParser(
        description=(
            "Send images to the image matching API and time the "
            "feature extraction endpoint."
        )
    )
    parser.add_argument(
        "--image0",
        required=False,
        help="Path to the first image",
        default="datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg",
    )
    # --image1 is only used by the (currently disabled) match benchmark below.
    parser.add_argument(
        "--image1",
        required=False,
        help="Path to the second image",
        default="datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg",
    )
    args = parser.parse_args()

    # get api version
    get_api_version()

    # request match
    # for i in range(10):
    #     t1 = time.time()
    #     preds = send_request_match(args.image0, args.image1)
    #     t2 = time.time()
    #     print(
    #         "Time cost1: {} seconds, matched: {}".format(
    #             (t2 - t1), len(preds["mmkeypoints0_orig"])
    #         )
    #     )

    # request extract
    for _ in range(1000):
        t1 = time.time()
        preds = send_request_extract(args.image0)
        t2 = time.time()
        print(f"Time cost2: {(t2 - t1)} seconds")

    # dump preds (only the result of the last request is kept)
    with open("preds.pkl", "wb") as f:
        pickle.dump(preds, f)