# Author: Tuan Tran
# Last commit 9e6ac41: "update helper functions"
# (Header copied from the file viewer: raw / history / blame links; file size 7.11 kB.)
import json
import logging
import mimetypes
import os
import typing as tp
from io import StringIO
from urllib.parse import unquote
import pandas as pd
import requests
from flask import Flask, Response, request, send_from_directory
from flask_cors import CORS
from tools import (get_leaderboard_filters, # Import your function
get_old_format_dataframe)
from backend.config import ABS_DATASET_DOMAIN, get_dataset_config, get_datasets
from backend.descriptions import (DATASET_DESCRIPTIONS, DESCRIPTIONS,
METRIC_DESCRIPTIONS, MODEL_DESCRIPTIONS)
from backend.examples import get_examples_tab
# Module logger. A stderr handler is attached only when no handler is found
# (hasHandlers() also looks at ancestor loggers such as the root logger), so
# importing this module under an already-configured host adds no duplicates.
logger = logging.getLogger(__name__)
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
    logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.warning("Starting the Flask app...")
# The built frontend bundle (../frontend/dist) is served as static files from
# the URL root, i.e. static_url_path="" maps "/<file>" onto that folder.
app = Flask(__name__, static_folder="../frontend/dist", static_url_path="")
# Enable CORS on all routes with flask_cors defaults (no origin restrictions
# are passed here).
CORS(app)
@app.route("/")
def index():
    """Serve the frontend entry page (index.html) from the app's static folder."""
    logger.warning("Serving index.html")
    entry_page = "index.html"
    static_dir = app.static_folder
    return send_from_directory(static_dir, entry_page)
@app.route("/datasets")
def datasets():
    """
    Return the dataset configs from ``get_datasets()`` serialized as JSON.

    Per the original docstring the configs are grouped by audio / image /
    video; the grouping itself is produced by ``get_datasets``.
    """
    serialized = json.dumps(get_datasets())
    return Response(serialized, mimetype="application/json")
@app.route("/data/<path:dataset_name>")
def data_files(dataset_name):
    """
    Serve one dataset CSV, post-processed into the JSON payload the frontend needs.

    The CSV is read from ``<config["path"]>/<dataset_name>_<dataset_type>.csv``,
    where ``config`` is ``get_dataset_config(dataset_name)``. Per the original
    docstring this path may point at S3 or at a local directory.

    Query parameters:
        dataset_type: ``"benchmark"`` selects the leaderboard payload
            (``get_leaderboard``); ``"attacks_variations"`` selects the chart
            payload (``get_chart``).

    Returns:
        The JSON response built by the selected builder, or an error with:
        400 if ``dataset_type`` is missing or unsupported; 404 if the CSV
        cannot be read; 500 if the CSV was read but could not be processed.
    """
    dataset_type = request.args.get("dataset_type")
    if not dataset_type:
        logger.error("No dataset_type provided in query parameters.")
        return "Dataset type not specified", 400

    # Validate the type before any I/O. Previously an unknown type fell through
    # and the view returned None, which Flask rejects. Validating here also keeps
    # arbitrary query text out of the file path built below.
    builders = {
        "benchmark": get_leaderboard,
        "attacks_variations": get_chart,
    }
    build_response = builders.get(dataset_type)
    if build_response is None:
        # Do not echo the raw value back: the plain-string response is served as HTML.
        logger.error("Unsupported dataset_type: %r", dataset_type)
        return "Unsupported dataset type", 400

    # NOTE(review): dataset_name comes from a <path:> URL segment and is joined
    # into a filesystem path; this assumes get_dataset_config rejects unknown
    # names (it is not visible here) -- confirm.
    dataset_config = get_dataset_config(dataset_name)
    file_path = (
        os.path.join(dataset_config["path"], dataset_name) + f"_{dataset_type}.csv"
    )
    logger.info(f"Looking for dataset file: {file_path}")

    # Only the read maps to 404. Read failures can be a missing local file or a
    # failed remote fetch, and those raise different exception types, hence the
    # broad catch.
    try:
        df = pd.read_csv(file_path)
    except Exception:
        logger.exception(f"Failed to fetch file: {file_path}")
        return "File not found", 404

    logger.info(f"Processing dataset: {dataset_name}")
    # Processing errors are server bugs, not missing files: log the traceback
    # and report 500 instead of disguising them as 404.
    try:
        return build_response(dataset_config, df)
    except Exception:
        logger.exception(f"Failed to process dataset: {dataset_name}")
        return "Failed to process dataset", 500
@app.route("/examples/<path:type>")
def example_files(type):
    """
    Return the examples for one tab, serialized as JSON.

    ``type`` is the path after ``/examples/``. It is handed unchanged to
    ``get_examples_tab``, and its name must match the route variable.
    Whether examples come from S3 or local files is decided inside
    ``get_examples_tab``, not here.
    """
    tab_examples = get_examples_tab(type)
    body = json.dumps(tab_examples)
    return Response(body, mimetype="application/json")
@app.route("/descriptions")
def descriptions():
    """
    Return the description tables imported from backend.descriptions as one JSON object.

    Keys: ``descriptions``, ``metric_descriptions``, ``model_descriptions``,
    ``dataset_descriptions``.
    """
    bundle = dict(
        descriptions=DESCRIPTIONS,
        metric_descriptions=METRIC_DESCRIPTIONS,
        model_descriptions=MODEL_DESCRIPTIONS,
        dataset_descriptions=DATASET_DESCRIPTIONS,
    )
    return Response(json.dumps(bundle), mimetype="application/json")
# Proxy endpoint: lets the frontend fetch resources under ABS_DATASET_DOMAIN
# through this server, sidestepping browser CORS restrictions.

# Upper bound (seconds) for connecting to / reading from the upstream server,
# so a stalled upstream cannot tie up a worker indefinitely.
_PROXY_TIMEOUT_SECONDS = 30

# Upstream headers not relayed to the client. requests has already decoded the
# body (so content-encoding no longer applies), and Flask sets length/framing
# for the body it actually sends.
_PROXY_EXCLUDED_HEADERS = frozenset(
    {"content-encoding", "content-length", "transfer-encoding", "connection"}
)


def _is_allowed_proxy_target(url, allowed_prefix):
    """
    Return True if ``url`` lies under ``allowed_prefix``.

    This is stricter than a bare ``startswith`` in two ways:
    * The character right after the prefix must be a boundary ("/", "\\",
      "?", "#") unless the prefix already ends with a separator. This way
      ``https://host.evil.com`` does not match an allowed ``https://host``.
    * Any ".." path segment is rejected, so neither URLs nor local paths can
      climb out of the allowed root.

    An empty ``allowed_prefix`` keeps its previous meaning of matching any
    URL, but ".." segments are still rejected.
    """
    from urllib.parse import urlsplit

    if not url.startswith(allowed_prefix):
        return False
    remainder = url[len(allowed_prefix):]
    if (
        remainder
        and allowed_prefix
        and not allowed_prefix.endswith(("/", "\\"))
        and remainder[0] not in "/\\?#"
    ):
        return False
    path = urlsplit(url).path if url.startswith(("http://", "https://")) else url
    return ".." not in path.replace("\\", "/").split("/")


def _proxy_remote(url):
    """Fetch ``url`` over HTTP(S) and relay its body and headers, adding a permissive CORS header."""
    # NOTE(review): requests follows redirects for GET by default, so an
    # allowed host that redirects could send us elsewhere -- confirm whether
    # allow_redirects=False is acceptable for the configured domain.
    with requests.get(url, stream=True, timeout=_PROXY_TIMEOUT_SECONDS) as response:
        if response.status_code != 200:
            return {"error": f"Failed to fetch from {url}"}, response.status_code
        headers = {
            name: value
            for name, value in response.headers.items()
            if name.lower() not in _PROXY_EXCLUDED_HEADERS
        }
        headers["Access-Control-Allow-Origin"] = "*"
        # .content reads the full body while the connection is still open.
        return Response(response.content, response.status_code, headers)


def _serve_local(local_path):
    """Return a local file's bytes with a MIME type guessed from its extension."""
    # isfile (not exists) so a directory yields a clean 404 instead of an
    # IsADirectoryError surfacing as a 500.
    if not os.path.isfile(local_path):
        return {"error": f"Local file not found: {local_path}"}, 404
    with open(local_path, "rb") as f:
        content = f.read()
    mime_type, _ = mimetypes.guess_type(local_path)
    return Response(
        content,
        mimetype=mime_type or "application/octet-stream",
        headers={"Access-Control-Allow-Origin": "*"},
    )


@app.route("/proxy/<path:url>")
def proxy(url):
    """
    Fetch a resource under ABS_DATASET_DOMAIN on behalf of the frontend.

    ``url`` arrives percent-encoded in the path and is decoded with ``unquote``.
    http(s) URLs are fetched remotely; anything else is treated as a local file path.

    Returns:
        The relayed content, or a JSON error with:
        403 if ``url`` is outside the allowed domain; the upstream status if
        it is not 200; 404 for a missing local file; 502 if the upstream
        request fails (including timeout); 500 for any other error.
    """
    try:
        url = unquote(url)
        if not _is_allowed_proxy_target(url, ABS_DATASET_DOMAIN):
            return {"error": "Only proxying from allowed domains is permitted"}, 403
        if url.startswith(("http://", "https://")):
            return _proxy_remote(url)
        return _serve_local(url)
    except requests.RequestException as e:
        logger.exception(f"Upstream request failed for {url}")
        return {"error": str(e)}, 502
    except Exception as e:
        logger.exception(f"Proxy request failed for {url}")
        return {"error": str(e)}, 500
def get_leaderboard(config, df):
    """
    Build the leaderboard JSON response for a benchmark dataset.

    Args:
        config: Dataset config. The keys read are "first_cols",
            "attack_scores" (passed to get_old_format_dataframe) and
            "categories" (passed to get_leaderboard_filters).
        df: Benchmark dataframe as read from CSV. After
            get_old_format_dataframe it must have a "model" column.

    Returns:
        A Flask JSON Response with:
        * "groups": group name -> list of metrics;
        * "default_selected_metrics": list of the default selection;
        * "rows": one record per metric, each holding "metric" plus one key
          per model.
    """
    logger.warning(f"Processing dataset with config: {config}")
    # Expand the frame to the full column set expected by the filters/table.
    expanded = get_old_format_dataframe(
        df, config["first_cols"], config["attack_scores"]
    )
    groups, default_selection = get_leaderboard_filters(expanded, config["categories"])

    # Missing values become the literal string "NaN" (not None), so json.dumps
    # never emits the non-standard bare NaN token. Models then become columns
    # and former columns become rows labelled by "metric".
    per_metric = (
        expanded.fillna(value="NaN")
        .set_index("model")
        .T
        .reset_index()
        .rename(columns={"index": "metric"})
    )

    payload = {
        "groups": {name: list(members) for name, members in groups.items()},
        "default_selected_metrics": list(default_selection),
        "rows": per_metric.to_dict(orient="records"),
    }
    return Response(json.dumps(payload), mimetype="application/json")
def get_chart(config, df):
    """
    Build the JSON response backing the attack-variations chart.

    Args:
        config: Dataset config; only its "attacks_with_variations" entry is read.
        df: Attack-variations dataframe as read from CSV.

    Returns:
        A Flask JSON Response with keys "metrics" (the plottable metric names),
        "attacks_with_variations" (copied from config) and "all_attacks_df"
        (the dataframe as a list of row records).
    """
    plot_metrics = ["bit_acc", "log10_p_value", "TPR", "FPR", "watermark_det_score"]
    # Missing values become the literal string "NaN" (not None), so json.dumps
    # never emits the non-standard bare NaN token.
    records = df.fillna(value="NaN").to_dict(orient="records")
    return Response(
        json.dumps(
            {
                "metrics": plot_metrics,
                "attacks_with_variations": config["attacks_with_variations"],
                "all_attacks_df": records,
            }
        ),
        mimetype="application/json",
    )
@app.errorhandler(404)
def not_found(e):
    """
    SPA fallback: answer any unmatched route with index.html.

    The frontend router can then resolve the path client-side. The error
    object ``e`` is unused, and the response is whatever send_from_directory
    returns for index.html.
    """
    spa_entry = "index.html"
    return send_from_directory(app.static_folder, spa_entry)
if __name__ == "__main__":
    # debug=True combined with host="0.0.0.0" would expose the Werkzeug
    # interactive debugger to every client that can reach port 7860. Debug
    # mode (and the auto-reloader that accompanies it) is therefore opt-in via
    # FLASK_DEBUG, using Flask's truthiness rule: any non-empty value other
    # than "0", "false" or "no" enables it.
    _debug_value = os.environ.get("FLASK_DEBUG", "")
    _debug = bool(_debug_value) and _debug_value.lower() not in {"0", "false", "no"}
    app.run(host="0.0.0.0", port=7860, debug=_debug, use_reloader=_debug)