import streamlit as st
import pandas as pd
import os
import json
import time
import requests
from concurrent.futures import ThreadPoolExecutor

# Streamlit app configuration
st.set_page_config(layout="wide")


# Function to make the API call and save the response
def make_api_call(query, params, url, headers, output_dir):
    data = {
        "message": query["message"],
        "stream": False,
        "return_prompt": True,
        **params,
    }
    start_time = time.time()
    response = requests.post(url, headers=headers, json=data)
    end_time = time.time()
    duration = end_time - start_time

    query_id = query["id"]
    query_dir = os.path.join(output_dir, f"query_{query_id}")
    os.makedirs(query_dir, exist_ok=True)

    file_name = f'response_{params["model"]}_{"multi_step" if params.get("tools") else "single_step"}.json'
    file_path = os.path.join(query_dir, file_name)

    # Parse once, then persist the raw response for later download
    response_data = response.json()
    with open(file_path, "w") as f:
        json.dump(response_data, f, indent=2)

    extracted_data = {
        "response_id": response_data.get("response_id"),
        "text": response_data.get("text"),
        "generation_id": response_data.get("generation_id"),
        "finish_reason": response_data.get("finish_reason"),
        "meta": {
            "api_version": response_data.get("meta", {})
            .get("api_version", {})
            .get("version"),
            "billed_units": {
                "input_tokens": response_data.get("meta", {})
                .get("billed_units", {})
                .get("input_tokens"),
                "output_tokens": response_data.get("meta", {})
                .get("billed_units", {})
                .get("output_tokens"),
            },
            "tokens": {
                "input_tokens": response_data.get("meta", {})
                .get("tokens", {})
                .get("input_tokens"),
                "output_tokens": response_data.get("meta", {})
                .get("tokens", {})
                .get("output_tokens"),
            },
        },
    }

    return {
        "query_id": query_id,
        "model": params["model"],
        "tools": params.get("tools", params.get("connectors")),
        "enable_hosted_multi_step": params.get("enable_hosted_multi_step"),
        "connectors": params.get("connectors"),
        "message": query["message"],
        "response": extracted_data["text"],
        **extracted_data,
        "params": params,
        "duration": duration,
    }


# Streamlit UI
st.title("Cohere API Benchmarking Tool")

# Input for API key
api_key = st.text_input("Enter your Cohere API key:", type="password")

# Input for messages
messages_input = st.text_area("Enter messages to benchmark (one per line):")
messages = [
    {"id": idx + 1, "message": msg}
    for idx, msg in enumerate(messages_input.split("\n"))
    if msg.strip()
]

# Define the combinations of parameters
param_combinations = [
    {
        "model": "command-r",
        "tools": [{"name": "internet_search"}],
        "enable_hosted_multi_step": True,
    },
    {"model": "command-r", "connectors": [{"id": "web-search", "name": "web search"}]},
    {
        "model": "command-r-plus",
        "tools": [{"name": "internet_search"}],
        "enable_hosted_multi_step": True,
    },
    {
        "model": "command-r-plus",
        "connectors": [{"id": "web-search", "name": "web search"}],
    },
]

# Define the API endpoint and headers
url = "https://production.api.cohere.ai/v1/chat"
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

# Create a directory to store the JSON files
output_dir = "api_responses"
os.makedirs(output_dir, exist_ok=True)

if st.button("Run Benchmark"):
    if not api_key:
        st.error("API key is required.")
    elif not messages:
        st.error("At least one message is required.")
    else:
        # Create a ThreadPoolExecutor for parallel execution
        with ThreadPoolExecutor() as executor:
            # Submit the API calls to the executor
            futures = []
            for query in messages:
                for params in param_combinations:
                    future = executor.submit(
                        make_api_call, query, params, url, headers, output_dir
                    )
                    futures.append(future)

            # Collect the results from the futures
            results = [future.result() for future in futures]

        # Create a DataFrame from the results
        df = pd.DataFrame(results)

        # Save the DataFrame to a CSV file
        df.to_csv("api_benchmarking_results.csv", index=False)

        st.success("Benchmarking completed!")

        # Display the DataFrame
        st.dataframe(df)

        # Offer download of the JSON files
        for query in messages:
            query_id = query["id"]
            query_dir = os.path.join(output_dir, f"query_{query_id}")
            st.markdown(f"### Query ID: {query_id}")
            for file_name in os.listdir(query_dir):
                file_path = os.path.join(query_dir, file_name)
                with open(file_path, "r") as f:
                    st.download_button(
                        label=f"Download {file_name}",
                        data=f.read(),
                        file_name=file_name,
                        mime="application/json",
                        key=f"{query_id}_{file_name}",
                    )

        # Visualization part
        st.title("Cohere Multi tool use - Benchmarking Results")

        # Group by query ID
        grouped = df.groupby("query_id")

        # Display each query and its data
        for query_id, group in grouped:
            st.header(f"Query ID: {query_id}")

            # Work on a copy so we don't mutate the groupby view
            group = group.copy()

            # Extract billed input and output tokens from meta
            group["billed_input_tokens"] = group["meta"].apply(
                lambda x: x.get("billed_units", {}).get("input_tokens", "N/A")
            )
            group["billed_output_tokens"] = group["meta"].apply(
                lambda x: x.get("billed_units", {}).get("output_tokens", "N/A")
            )

            # Display the runs for each query in a table
            st.write(
                group[
                    [
                        "message",
                        "response",
                        "billed_input_tokens",
                        "billed_output_tokens",
                        "model",
                        "tools",
                        "enable_hosted_multi_step",
                        "connectors",
                        "duration",
                    ]
                ]
            )

            # Toggle for detailed information
            with st.expander("Show Details"):
                for index, row in group.iterrows():
                    st.subheader(f"Run {index + 1}")
                    st.write(f"**Model:** {row['model']}")
                    st.write(f"**Tools:** {row['tools']}")
                    st.write(
                        f"**Enable Hosted Multi-Step:** {row['enable_hosted_multi_step']}"
                    )
                    st.write(f"**Connectors:** {row['connectors']}")
                    st.write(f"**Message:** {row['message']}")
                    st.write(f"**Response:** {row['response']}")
                    st.write(f"**Response ID:** {row['response_id']}")
                    st.write(f"**Text:** {row['text']}")
                    st.write(f"**Generation ID:** {row['generation_id']}")
                    st.write(f"**Finish Reason:** {row['finish_reason']}")
                    st.write(f"**Meta:** {row['meta']}")
                    st.write(f"**Params:** {row['params']}")
                    st.write(f"**Duration:** {row['duration']}")

                    # Parse and display meta field
                    meta = row["meta"]
                    st.write(f"**API Version:** {meta.get('api_version', 'N/A')}")
                    billed_units = meta.get("billed_units", {})
                    st.write(
                        f"**Billed Units - Input Tokens:** {billed_units.get('input_tokens', 'N/A')}"
                    )
                    st.write(
                        f"**Billed Units - Output Tokens:** {billed_units.get('output_tokens', 'N/A')}"
                    )
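
# Usage note (assumption: this file is saved as app.py; the `streamlit`,
# `pandas`, and `requests` packages are installed). Launch the app with:
#
#     streamlit run app.py
#
# Then enter a Cohere API key and one benchmark message per line, and click
# "Run Benchmark" to fan the requests out across the parameter combinations.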