# Last commit by: Hamed Mohammadpour
# Commit message: Add unique id for each button
# Commit: 1ea4886
import streamlit as st
import pandas as pd
import os
import json
import time
import requests
from concurrent.futures import ThreadPoolExecutor
# Streamlit app configuration: switch the page to the full-width ("wide")
# layout. This is the first Streamlit call made by the script.
st.set_page_config(
    layout="wide",
)
def _summarize_response(response_data):
    """Pick the benchmark-relevant fields out of a chat API JSON response.

    Missing keys, and nested objects that are missing or JSON ``null``,
    produce ``None`` values instead of raising.

    Args:
        response_data: Decoded JSON response body (a dict).

    Returns:
        Dict with ``response_id``, ``text``, ``generation_id``,
        ``finish_reason`` and a ``meta`` dict holding ``api_version`` plus
        ``billed_units`` / ``tokens`` input and output token counts.
    """
    # `or {}` also covers an explicit null, which `.get(key, {})` does not.
    meta = response_data.get("meta") or {}
    api_version = meta.get("api_version") or {}
    billed_units = meta.get("billed_units") or {}
    tokens = meta.get("tokens") or {}
    return {
        "response_id": response_data.get("response_id"),
        "text": response_data.get("text"),
        "generation_id": response_data.get("generation_id"),
        "finish_reason": response_data.get("finish_reason"),
        "meta": {
            "api_version": api_version.get("version"),
            "billed_units": {
                "input_tokens": billed_units.get("input_tokens"),
                "output_tokens": billed_units.get("output_tokens"),
            },
            "tokens": {
                "input_tokens": tokens.get("input_tokens"),
                "output_tokens": tokens.get("output_tokens"),
            },
        },
    }


# Function to make the API call and save the response
def make_api_call(query, params, url, headers, output_dir, timeout=300):
    """Send one benchmark query to the chat endpoint and record the result.

    The response body is written as JSON to
    ``<output_dir>/query_<id>/response_<model>_<multi_step|single_step>.json``.
    A flat summary dict is returned for tabulation.

    Args:
        query: Dict with at least ``"id"`` and ``"message"`` keys.
        params: Extra request fields merged into the payload. Must contain
            ``"model"``. A truthy ``"tools"`` entry selects the
            ``multi_step`` file-name suffix.
        url: Endpoint to POST to.
        headers: HTTP headers sent with the request.
        output_dir: Root directory for the per-query JSON files.
        timeout: Seconds ``requests`` waits for the server before raising a
            timeout error. ``None`` waits indefinitely.

    Returns:
        Dict with the query and parameter metadata, the summarized response
        fields, the elapsed request time in seconds (``duration``) and the
        HTTP ``status_code``.

    Raises:
        requests.RequestException: On connection failures or timeouts.
    """
    query_id = query["id"]
    data = {
        "message": query["message"],
        "stream": False,
        "return_prompt": True,
        **params,  # listed last, so params may override the defaults above
    }

    # perf_counter is monotonic and high-resolution, unlike time.time(), so
    # the measured latency cannot be skewed by system clock adjustments.
    start_time = time.perf_counter()
    response = requests.post(url, headers=headers, json=data, timeout=timeout)
    duration = time.perf_counter() - start_time

    try:
        response_data = response.json()
    except ValueError:
        # Error pages (e.g. from a proxy) may not be JSON. Keep the raw body
        # so the run is still saved and shows up in the results.
        response_data = {"raw_body": response.text}
    if not isinstance(response_data, dict):
        response_data = {"raw_body": response_data}

    query_dir = os.path.join(output_dir, f"query_{query_id}")
    os.makedirs(query_dir, exist_ok=True)
    mode = "multi_step" if params.get("tools") else "single_step"
    file_path = os.path.join(query_dir, f'response_{params["model"]}_{mode}.json')
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(response_data, f, indent=2)

    extracted_data = _summarize_response(response_data)
    return {
        "query_id": query_id,
        "model": params["model"],
        "tools": params.get("tools", params.get("connectors")),
        "enable_hosted_multi_step": params.get("enable_hosted_multi_step"),
        "connectors": params.get("connectors"),
        "message": query["message"],
        "response": extracted_data["text"],
        **extracted_data,
        "params": params,
        "duration": duration,
        # Lets failed calls (4xx/5xx) be told apart from empty answers.
        "status_code": response.status_code,
    }
# Streamlit UI
st.title("Cohere API Benchmarking Tool")
# Input for API key
api_key = st.text_input("Enter your Cohere API key:", type="password")
# Input for messages: one benchmark query per non-blank line. Each query's id
# is its 1-based line number in the text area. splitlines() also handles
# "\r\n" endings, and strip() keeps stray whitespace out of the API payload.
messages_input = st.text_area("Enter messages to benchmark (one per line):")
messages = [
    {"id": line_no, "message": line.strip()}
    for line_no, line in enumerate(messages_input.splitlines(), start=1)
    if line.strip()
]
# Define the combinations of parameters. Each model gets two runs:
# - the hosted multi-step "internet_search" tool;
# - the "web-search" connector (saved with the single_step file suffix).
param_combinations = [
    combo
    for model_name in ("command-r", "command-r-plus")
    for combo in (
        {
            "model": model_name,
            "tools": [{"name": "internet_search"}],
            "enable_hosted_multi_step": True,
        },
        {
            "model": model_name,
            "connectors": [{"id": "web-search", "name": "web search"}],
        },
    )
]
# Every benchmark call POSTs JSON to this chat endpoint with a bearer token.
url = "https://production.api.cohere.ai/v1/chat"
headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer {}".format(api_key),
}
# Root directory for the saved JSON responses (one sub-folder per query).
output_dir = "api_responses"
os.makedirs(output_dir, exist_ok=True)
if st.button("Run Benchmark"):
    if not api_key:
        st.error("API key is required.")
    elif not messages:
        st.error("At least one message is required.")
    else:
        # Run every (message, parameter combination) pair in parallel. The
        # calls are network-bound, so threads overlap the waiting.
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(make_api_call, query, params, url, headers, output_dir)
                for query in messages
                for params in param_combinations
            ]
            # result() re-raises any exception from a worker, so one failed
            # call aborts the run with that error.
            results = [future.result() for future in futures]

        # Tabulate the results, save them as CSV and show them.
        df = pd.DataFrame(results)
        df.to_csv("api_benchmarking_results.csv", index=False)
        st.success("Benchmarking completed!")
        st.dataframe(df)

        # Offer download of the saved JSON responses, grouped by query.
        # NOTE(review): by default, clicking a Streamlit download button
        # reruns the script. st.button() is then False, so these results
        # disappear. Keeping `df` in st.session_state would avoid that.
        for query in messages:
            query_id = query["id"]
            query_dir = os.path.join(output_dir, f"query_{query_id}")
            st.markdown(f"### Query ID: {query_id}")
            # os.listdir order is arbitrary; sort it for a stable button order.
            for file_name in sorted(os.listdir(query_dir)):
                file_path = os.path.join(query_dir, file_name)
                with open(file_path, "r", encoding="utf-8") as f:
                    file_contents = f.read()
                st.download_button(
                    label=f"Download {file_name}",
                    data=file_contents,
                    file_name=file_name,
                    mime="application/json",
                    # Widget keys must be unique on the page. The query id
                    # plus file name is unique per button.
                    key=f"{query_id}_{file_name}",
                )

        # Visualization part
        st.title("Cohere Multi tool use - Benchmarking Results")
        for query_id, group in df.groupby("query_id"):
            st.header(f"Query ID: {query_id}")
            # Copy the billed token counts from the nested "meta" dict into
            # their own columns. assign() returns a new frame, avoiding the
            # pandas SettingWithCopyWarning from writing into a groupby slice.
            group = group.assign(
                billed_input_tokens=group["meta"].apply(
                    lambda meta: meta.get("billed_units", {}).get("input_tokens", "N/A")
                ),
                billed_output_tokens=group["meta"].apply(
                    lambda meta: meta.get("billed_units", {}).get("output_tokens", "N/A")
                ),
            )
            # Display the runs for this query in a table.
            st.write(
                group[
                    [
                        "message",
                        "response",
                        "billed_input_tokens",
                        "billed_output_tokens",
                        "model",
                        "tools",
                        "enable_hosted_multi_step",
                        "connectors",
                        "duration",
                    ]
                ]
            )
            # Collapsible per-run details.
            with st.expander("Show Details"):
                # Number runs within this query (1, 2, ...) rather than by the
                # global DataFrame index, which keeps counting across queries.
                for run_number, (_, row) in enumerate(group.iterrows(), start=1):
                    st.subheader(f"Run {run_number}")
                    st.write(f"**Model:** {row['model']}")
                    st.write(f"**Tools:** {row['tools']}")
                    st.write(
                        f"**Enable Hosted Multi-Step:** {row['enable_hosted_multi_step']}"
                    )
                    st.write(f"**Connectors:** {row['connectors']}")
                    st.write(f"**Message:** {row['message']}")
                    st.write(f"**Response:** {row['response']}")
                    st.write(f"**Response ID:** {row['response_id']}")
                    st.write(f"**Text:** {row['text']}")
                    st.write(f"**Generation ID:** {row['generation_id']}")
                    st.write(f"**Finish Reason:** {row['finish_reason']}")
                    st.write(f"**Meta:** {row['meta']}")
                    st.write(f"**Params:** {row['params']}")
                    st.write(f"**Duration:** {row['duration']}")
                    # Parse and display the nested meta field.
                    meta = row["meta"]
                    st.write(f"**API Version:** {meta.get('api_version', 'N/A')}")
                    billed_units = meta.get("billed_units", {})
                    st.write(
                        f"**Billed Units - Input Tokens:** {billed_units.get('input_tokens', 'N/A')}"
                    )
                    st.write(
                        f"**Billed Units - Output Tokens:** {billed_units.get('output_tokens', 'N/A')}"
                    )