import json
import logging
import mimetypes
import os
import typing as tp
from io import StringIO
from urllib.parse import unquote

import pandas as pd
import requests
from flask import Flask, Response, request, send_from_directory
from flask_cors import CORS
from tools import get_leaderboard_filters, get_old_format_dataframe

from backend.config import ABS_DATASET_DOMAIN, get_dataset_config, get_datasets
from backend.descriptions import (
    DATASET_DESCRIPTIONS,
    DESCRIPTIONS,
    METRIC_DESCRIPTIONS,
    MODEL_DESCRIPTIONS,
)
from backend.examples import get_examples_tab

logger = logging.getLogger(__name__)
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
    logger.addHandler(handler)
logger.setLevel(logging.INFO)

logger.warning("Starting the Flask app...")

# The built frontend is served as static files from ../frontend/dist.
app = Flask(__name__, static_folder="../frontend/dist", static_url_path="")
CORS(app)

# Metrics offered in the attacks-variations chart (see get_chart).
_ATTACKS_PLOT_METRICS = (
    "bit_acc",
    "log10_p_value",
    "TPR",
    "FPR",
    "watermark_det_score",
)

# Hop-by-hop / body-framing headers that must not be copied from the upstream
# response. requests has already decoded the body and Flask recomputes the
# length, so forwarding these would describe the wrong payload.
_EXCLUDED_PROXY_HEADERS = frozenset(
    {"content-encoding", "content-length", "transfer-encoding", "connection"}
)

# Seconds to wait on the upstream server before giving up. Without a timeout a
# stalled upstream would block this worker indefinitely.
_PROXY_TIMEOUT_S = 30


@app.route("/")
def index():
    """Serve the single-page frontend entry point."""
    logger.warning("Serving index.html")
    return send_from_directory(app.static_folder, "index.html")


@app.route("/datasets")
def datasets():
    """
    Returns the dataset configs grouped by audio / image / video.
    """
    return Response(json.dumps(get_datasets()), mimetype="application/json")


@app.route("/data/<dataset_name>")
def data_files(dataset_name):
    """Serve a dataset CSV, processed for the frontend, as JSON.

    The CSV is read from ``<path>/<dataset_name>_<dataset_type>.csv``, where
    ``path`` comes from ``get_dataset_config(dataset_name)["path"]`` (local or
    remote, depending on config) and ``dataset_type`` is the ``dataset_type``
    query parameter.

    Supported ``dataset_type`` values:
      * ``benchmark``          -> leaderboard payload (``get_leaderboard``)
      * ``attacks_variations`` -> chart payload (``get_chart``)

    Returns 400 when ``dataset_type`` is missing or unsupported and 404 when
    the CSV cannot be read. Errors while *processing* a successfully read CSV
    are not masked as 404; they propagate to Flask's error handling.
    """
    dataset_type = request.args.get("dataset_type")
    if not dataset_type:
        logger.error("No dataset_type provided in query parameters.")
        return "Dataset type not specified", 400

    # Validate the type before touching storage. Previously an unknown type
    # fell through and returned None, which Flask rejects with a 500.
    builders = {"benchmark": get_leaderboard, "attacks_variations": get_chart}
    builder = builders.get(dataset_type)
    if builder is None:
        logger.error(f"Unsupported dataset_type: {dataset_type}")
        return f"Unsupported dataset type: {dataset_type}", 400

    dataset_config = get_dataset_config(dataset_name)
    file_path = (
        os.path.join(dataset_config["path"], dataset_name) + f"_{dataset_type}.csv"
    )
    logger.info(f"Looking for dataset file: {file_path}")

    try:
        df = pd.read_csv(file_path)
    except Exception:
        # The storage backend (local or remote) determines which exception a
        # missing file raises, so any read failure is reported as 404. The
        # traceback is kept in the log.
        logger.exception(f"Failed to fetch file: {file_path}")
        return "File not found", 404

    logger.info(f"Processing dataset: {dataset_name}")
    return builder(dataset_config, df)


@app.route("/examples/<type>")
def example_files(type):
    """
    Serve example files from S3 or locally based on config
    """
    result = get_examples_tab(type)
    return Response(json.dumps(result), mimetype="application/json")


@app.route("/descriptions")
def descriptions():
    """
    Serve descriptions and model descriptions from descriptions.py
    """
    return Response(
        json.dumps(
            {
                "descriptions": DESCRIPTIONS,
                "metric_descriptions": METRIC_DESCRIPTIONS,
                "model_descriptions": MODEL_DESCRIPTIONS,
                "dataset_descriptions": DATASET_DESCRIPTIONS,
            }
        ),
        mimetype="application/json",
    )


def _is_remote(url):
    """Return True if ``url`` is an http(s) URL rather than a local path."""
    return url.startswith(("http://", "https://"))


def _is_allowed_proxy_target(url):
    """Return True if ``url`` lies under ``ABS_DATASET_DOMAIN``.

    The prefix must match on a path-segment boundary, so ``https://host.evil``
    is rejected when the allowed base is ``https://host``.

    - Remote URLs containing ``..`` segments are rejected, so the request
      cannot climb out of the allowed base path.
    - Local paths are resolved (``..`` and symlinks) and must remain inside
      the allowed directory.
    """
    base = ABS_DATASET_DOMAIN.rstrip("/")
    if not (url == base or url.startswith(base + "/")):
        return False
    if _is_remote(url):
        return ".." not in url.split("/")
    real_base = os.path.realpath(base)
    real_target = os.path.realpath(url)
    return os.path.commonpath([real_base, real_target]) == real_base


def _proxy_remote(url):
    """Fetch ``url`` and relay its body and (filtered) headers with CORS enabled."""
    with requests.get(url, timeout=_PROXY_TIMEOUT_S) as upstream:
        if upstream.status_code != 200:
            return {"error": f"Failed to fetch from {url}"}, upstream.status_code
        headers = {
            name: value
            for name, value in upstream.headers.items()
            if name.lower() not in _EXCLUDED_PROXY_HEADERS
        }
        headers["Access-Control-Allow-Origin"] = "*"
        return Response(upstream.content, upstream.status_code, headers)


def _serve_local(local_path):
    """Return a local file's bytes, with a content type guessed from its name."""
    if not os.path.isfile(local_path):
        return {"error": f"Local file not found: {local_path}"}, 404
    with open(local_path, "rb") as f:
        content = f.read()
    mime_type, _ = mimetypes.guess_type(local_path)
    return Response(
        content,
        mimetype=mime_type or "application/octet-stream",
        headers={"Access-Control-Allow-Origin": "*"},
    )


# Add a proxy endpoint to bypass CORS issues
@app.route("/proxy/<path:url>")
def proxy(url):
    """
    Proxy endpoint to fetch remote files and serve them to the frontend.
    This helps bypass CORS restrictions on remote resources.

    ``url`` is percent-decoded and must lie under ``ABS_DATASET_DOMAIN``
    (otherwise 403). http(s) targets are fetched with ``requests``; anything
    else is treated as a local file path. Unexpected failures return a 500
    JSON error and are logged with a traceback.
    """
    try:
        url = unquote(url)
        # Only proxy from the trusted location, for security.
        if not _is_allowed_proxy_target(url):
            return {"error": "Only proxying from allowed domains is permitted"}, 403
        if _is_remote(url):
            return _proxy_remote(url)
        return _serve_local(url)
    except Exception as e:
        logger.exception(f"Proxy request failed for {url}")
        return {"error": str(e)}, 500


def get_leaderboard(config, df):
    """Build the leaderboard JSON payload from a benchmark DataFrame.

    Reads ``config["first_cols"]``, ``config["attack_scores"]`` and
    ``config["categories"]``. ``df`` must contain a ``model`` column.

    The response JSON contains:
      * ``groups`` — mapping of group name to list of metric names.
      * ``default_selected_metrics`` — list of metric names.
      * ``rows`` — one record per metric, with a ``metric`` key plus one key
        per model (the frame is transposed so models become columns).
    """
    logger.warning(f"Processing dataset with config: {config}")

    # Expand the frame with all derived columns.
    df = get_old_format_dataframe(df, config["first_cols"], config["attack_scores"])
    groups, default_selection = get_leaderboard_filters(df, config["categories"])

    # Replace NaN with the *string* "NaN": json.dumps would otherwise emit a
    # bare NaN token, which is not valid JSON.
    df = df.fillna(value="NaN")

    # Transpose so each original column becomes a row and each model a column.
    df = df.set_index("model").T.reset_index()
    df = df.rename(columns={"index": "metric"})

    result = {
        "groups": {group: list(metrics) for group, metrics in groups.items()},
        "default_selected_metrics": list(default_selection),
        "rows": df.to_dict(orient="records"),
    }
    return Response(json.dumps(result), mimetype="application/json")


def get_chart(config, df):
    """Build the attacks-variations chart JSON payload.

    The response JSON contains:
      * ``metrics`` — the plottable metric names.
      * ``attacks_with_variations`` — taken verbatim from the config.
      * ``all_attacks_df`` — every row of ``df`` as a record.
    """
    # Replace NaN with the string "NaN" so json.dumps produces valid JSON.
    df = df.fillna(value="NaN")
    chart_data = {
        "metrics": list(_ATTACKS_PLOT_METRICS),
        "attacks_with_variations": config["attacks_with_variations"],
        "all_attacks_df": df.to_dict(orient="records"),
    }
    return Response(json.dumps(chart_data), mimetype="application/json")


@app.errorhandler(404)
def not_found(e):
    """SPA fallback: serve index.html so client-side routes resolve."""
    return send_from_directory(app.static_folder, "index.html")


if __name__ == "__main__":
    # debug=True turns on the Werkzeug interactive debugger (a Python console
    # in the browser). Never enable it by default on a server bound to
    # 0.0.0.0; opt in with FLASK_DEBUG=1 for local development.
    debug = os.environ.get("FLASK_DEBUG", "").lower() in {"1", "true", "yes"}
    app.run(host="0.0.0.0", port=7860, debug=debug, use_reloader=debug)