# Streamlit app: benchmark the Cohere chat API across model / tool configurations.
# Standard library
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor

# Third-party
import pandas as pd
import requests
import streamlit as st

# Streamlit app configuration: use the full browser width.
st.set_page_config(layout="wide")
# Function to make the API call and save the response
def make_api_call(query, params, url, headers, output_dir, timeout=None):
    """Call the Cohere chat endpoint for one query/parameter combination.

    The raw JSON response is written to
    ``<output_dir>/query_<id>/response_<model>_<mode>.json`` and a flat
    summary record suitable for a DataFrame row is returned.

    Args:
        query: Dict with ``"id"`` and ``"message"`` keys.
        params: Request parameters; must include ``"model"`` and may include
            ``"tools"``, ``"connectors"``, ``"enable_hosted_multi_step"``.
        url: Chat endpoint URL.
        headers: HTTP headers (content type + bearer auth).
        output_dir: Directory under which per-query response files are saved.
        timeout: Optional requests timeout in seconds. ``None`` waits
            indefinitely, matching the previous behavior.

    Returns:
        Dict with the run configuration, extracted response fields, and the
        wall-clock request duration in seconds.
    """
    data = {
        "message": query["message"],
        "stream": False,
        "return_prompt": True,
        **params,
    }

    # Time only the request itself (network + server latency).
    start_time = time.time()
    response = requests.post(url, headers=headers, json=data, timeout=timeout)
    duration = time.time() - start_time

    # Parse the body once and reuse it (previously parsed twice).
    response_data = response.json()

    # Persist the raw JSON under <output_dir>/query_<id>/.
    query_id = query["id"]
    query_dir = os.path.join(output_dir, f"query_{query_id}")
    os.makedirs(query_dir, exist_ok=True)
    file_name = (
        f'response_{params["model"]}_'
        f'{"multi_step" if params.get("tools") else "single_step"}.json'
    )
    file_path = os.path.join(query_dir, file_name)
    with open(file_path, "w") as f:
        json.dump(response_data, f, indent=2)

    # Hoist the nested meta lookups instead of re-fetching them per field.
    meta = response_data.get("meta", {})
    billed = meta.get("billed_units", {})
    tokens = meta.get("tokens", {})
    extracted_data = {
        "response_id": response_data.get("response_id"),
        "text": response_data.get("text"),
        "generation_id": response_data.get("generation_id"),
        "finish_reason": response_data.get("finish_reason"),
        "meta": {
            "api_version": meta.get("api_version", {}).get("version"),
            "billed_units": {
                "input_tokens": billed.get("input_tokens"),
                "output_tokens": billed.get("output_tokens"),
            },
            "tokens": {
                "input_tokens": tokens.get("input_tokens"),
                "output_tokens": tokens.get("output_tokens"),
            },
        },
    }

    return {
        "query_id": query_id,
        "model": params["model"],
        # "tools" falls back to "connectors" so one column captures either mode.
        "tools": params.get("tools", params.get("connectors")),
        "enable_hosted_multi_step": params.get("enable_hosted_multi_step"),
        "connectors": params.get("connectors"),
        "message": query["message"],
        "response": extracted_data["text"],
        **extracted_data,
        "params": params,
        "duration": duration,
    }
# ---- Streamlit UI: inputs ----
st.title("Cohere API Benchmarking Tool")

# API key (masked input).
api_key = st.text_input("Enter your Cohere API key:", type="password")

# One benchmark message per line of the text area.
messages_input = st.text_area("Enter messages to benchmark (one per line):")

# Build {"id": line_number, "message": text} records, skipping blank lines
# (ids therefore track the original line numbers, starting at 1).
messages = []
for line_no, msg in enumerate(messages_input.split("\n")):
    if msg.strip():
        messages.append({"id": line_no + 1, "message": msg})

# Every (model, tool-mode) combination benchmarked for each message.
param_combinations = [
    {
        "model": "command-r",
        "tools": [{"name": "internet_search"}],
        "enable_hosted_multi_step": True,
    },
    {"model": "command-r", "connectors": [{"id": "web-search", "name": "web search"}]},
    {
        "model": "command-r-plus",
        "tools": [{"name": "internet_search"}],
        "enable_hosted_multi_step": True,
    },
    {
        "model": "command-r-plus",
        "connectors": [{"id": "web-search", "name": "web search"}],
    },
]

# Chat endpoint and auth headers.
url = "https://production.api.cohere.ai/v1/chat"
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

# Directory holding one sub-folder of raw JSON responses per query.
output_dir = "api_responses"
os.makedirs(output_dir, exist_ok=True)
if st.button("Run Benchmark"):
    # Validate inputs before doing any work.
    if not api_key:
        st.error("API key is required.")
    elif not messages:
        st.error("At least one message is required.")
    else:
        # Fan out one API call per (message, parameter-set) pair in parallel;
        # the calls are I/O-bound, so threads overlap the network waits.
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(
                    make_api_call, query, params, url, headers, output_dir
                )
                for query in messages
                for params in param_combinations
            ]
            # .result() blocks until done and re-raises any worker exception.
            results = [future.result() for future in futures]

        # Tabulate and persist the summary.
        df = pd.DataFrame(results)
        df.to_csv("api_benchmarking_results.csv", index=False)
        st.success("Benchmarking completed!")
        st.dataframe(df)

        # Offer each saved raw-response JSON for download, grouped by query.
        for query in messages:
            query_id = query["id"]
            query_dir = os.path.join(output_dir, f"query_{query_id}")
            st.markdown(f"### Query ID: {query_id}")
            for file_name in os.listdir(query_dir):
                file_path = os.path.join(query_dir, file_name)
                with open(file_path, "r") as f:
                    st.download_button(
                        label=f"Download {file_name}",
                        data=f.read(),
                        file_name=file_name,
                        mime="application/json",
                        key=f"{query_id}_{file_name}",
                    )

        # ---- Visualization ----
        st.title("Cohere Multi tool use - Benchmarking Results")

        # Display each query and its runs.
        for query_id, group in df.groupby("query_id"):
            st.header(f"Query ID: {query_id}")

            # Work on a copy: assigning new columns to a groupby slice would
            # raise pandas' SettingWithCopyWarning and may not stick.
            group = group.copy()
            group["billed_input_tokens"] = group["meta"].apply(
                lambda x: x.get("billed_units", {}).get("input_tokens", "N/A")
            )
            group["billed_output_tokens"] = group["meta"].apply(
                lambda x: x.get("billed_units", {}).get("output_tokens", "N/A")
            )

            # Summary table of all runs for this query.
            st.write(
                group[
                    [
                        "message",
                        "response",
                        "billed_input_tokens",
                        "billed_output_tokens",
                        "model",
                        "tools",
                        "enable_hosted_multi_step",
                        "connectors",
                        "duration",
                    ]
                ]
            )

            # Toggle for detailed per-run information.
            with st.expander("Show Details"):
                # enumerate gives a per-query run counter; the raw DataFrame
                # index is global across queries and would mislabel the runs.
                for run_no, (_, row) in enumerate(group.iterrows(), start=1):
                    st.subheader(f"Run {run_no}")
                    st.write(f"**Model:** {row['model']}")
                    st.write(f"**Tools:** {row['tools']}")
                    st.write(
                        f"**Enable Hosted Multi-Step:** {row['enable_hosted_multi_step']}"
                    )
                    st.write(f"**Connectors:** {row['connectors']}")
                    st.write(f"**Message:** {row['message']}")
                    st.write(f"**Response:** {row['response']}")
                    st.write(f"**Response ID:** {row['response_id']}")
                    st.write(f"**Text:** {row['text']}")
                    st.write(f"**Generation ID:** {row['generation_id']}")
                    st.write(f"**Finish Reason:** {row['finish_reason']}")
                    st.write(f"**Meta:** {row['meta']}")
                    st.write(f"**Params:** {row['params']}")
                    st.write(f"**Duration:** {row['duration']}")

                    # Break the nested meta dict out into readable fields.
                    meta = row["meta"]
                    st.write(f"**API Version:** {meta.get('api_version', 'N/A')}")
                    billed_units = meta.get("billed_units", {})
                    st.write(
                        f"**Billed Units - Input Tokens:** {billed_units.get('input_tokens', 'N/A')}"
                    )
                    st.write(
                        f"**Billed Units - Output Tokens:** {billed_units.get('output_tokens', 'N/A')}"
                    )