Spaces:
Sleeping
Sleeping
Hamed Mohammadpour
committed on
Commit
·
d2c6071
1
Parent(s):
62857df
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import time
|
6 |
+
import requests
|
7 |
+
from concurrent.futures import ThreadPoolExecutor
|
8 |
+
|
9 |
+
# Streamlit app configuration
|
10 |
+
st.set_page_config(layout="wide")
|
11 |
+
|
12 |
+
# Function to make the API call and save the response
def make_api_call(query, params, url, headers, output_dir):
    """Send one chat request, persist the raw JSON response, and return a flat
    summary row for tabulation.

    Parameters:
        query: dict with "id" (int) and "message" (str) keys.
        params: per-run API parameters; must contain "model", may contain
            "tools", "connectors", "enable_hosted_multi_step".
        url: chat endpoint URL.
        headers: HTTP headers including the Authorization bearer token.
        output_dir: base directory; responses land in query_<id>/ beneath it.

    Returns:
        dict summarizing the run (model, params, extracted response fields,
        wall-clock duration in seconds).
    """
    data = {
        "message": query["message"],
        "stream": False,
        "return_prompt": True,
        **params,
    }

    start_time = time.time()
    response = requests.post(url, headers=headers, json=data)
    duration = time.time() - start_time

    # Parse the body exactly once (the original called response.json() twice,
    # re-parsing the full payload).
    response_data = response.json()

    query_id = query["id"]
    query_dir = os.path.join(output_dir, f'query_{query_id}')
    os.makedirs(query_dir, exist_ok=True)

    # "tools" present => multi-step run; otherwise single-step (connectors).
    file_name = f'response_{params["model"]}_{"multi_step" if params.get("tools") else "single_step"}.json'
    file_path = os.path.join(query_dir, file_name)
    with open(file_path, 'w') as f:
        json.dump(response_data, f, indent=2)

    # Hoist the nested lookups instead of repeating long .get() chains.
    meta = response_data.get("meta", {})
    billed = meta.get("billed_units", {})
    tokens = meta.get("tokens", {})
    extracted_data = {
        "response_id": response_data.get("response_id"),
        "text": response_data.get("text"),
        "generation_id": response_data.get("generation_id"),
        "finish_reason": response_data.get("finish_reason"),
        "meta": {
            "api_version": meta.get("api_version", {}).get("version"),
            "billed_units": {
                "input_tokens": billed.get("input_tokens"),
                "output_tokens": billed.get("output_tokens"),
            },
            "tokens": {
                "input_tokens": tokens.get("input_tokens"),
                "output_tokens": tokens.get("output_tokens"),
            },
        },
    }

    return {
        "query_id": query_id,
        "model": params["model"],
        # Fall back to connectors when no tools were configured for this run.
        "tools": params.get("tools", params.get("connectors")),
        "enable_hosted_multi_step": params.get("enable_hosted_multi_step"),
        "connectors": params.get("connectors"),
        "message": query["message"],
        "response": extracted_data["text"],
        **extracted_data,
        "params": params,
        "duration": duration,
    }
|
63 |
+
|
64 |
+
# Streamlit UI
st.title('Cohere API Benchmarking Tool')

# Input for API key
api_key = st.text_input('Enter your Cohere API key:', type='password')

# Input for messages — one query per non-blank line; IDs follow the line
# number in the textarea (blank lines consume an ID but produce no query).
messages_input = st.text_area('Enter messages to benchmark (one per line):')
messages = [
    {"id": line_no, "message": line}
    for line_no, line in enumerate(messages_input.split('\n'), start=1)
    if line.strip()
]

# Define the combinations of parameters: for each model, one hosted
# multi-step (tools) run and one connector-based single-step run.
param_combinations = []
for model_name in ("command-r", "command-r-plus"):
    param_combinations.append({
        "model": model_name,
        "tools": [{"name": "internet_search"}],
        "enable_hosted_multi_step": True,
    })
    param_combinations.append({
        "model": model_name,
        "connectors": [{"id": "web-search", "name": "web search"}],
    })

# Define the API endpoint and headers
url = 'https://production.api.cohere.ai/v1/chat'
headers = {
    'Content-Type': 'application/json',
    'Authorization': f'Bearer {api_key}',
}

# Create a directory to store the JSON files
output_dir = 'api_responses'
os.makedirs(output_dir, exist_ok=True)
|
92 |
+
|
93 |
+
if st.button('Run Benchmark'):
    if not api_key:
        st.error('API key is required.')
    elif not messages:
        st.error('At least one message is required.')
    else:
        # Fan the API calls out in parallel: one task per (query, params) pair.
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(make_api_call, query, params, url, headers, output_dir)
                for query in messages
                for params in param_combinations
            ]
            # Collect the results in submission order.
            results = [future.result() for future in futures]

        # Create a DataFrame from the results and save it to a CSV file.
        df = pd.DataFrame(results)
        df.to_csv('api_benchmarking_results.csv', index=False)

        st.success('Benchmarking completed!')

        # Display the DataFrame
        st.dataframe(df)

        # Offer download of the JSON files
        for query in messages:
            query_id = query["id"]
            query_dir = os.path.join(output_dir, f'query_{query_id}')
            st.markdown(f"### Query ID: {query_id}")
            for file_name in os.listdir(query_dir):
                file_path = os.path.join(query_dir, file_name)
                with open(file_path, 'r') as f:
                    # key=file_path keeps widget IDs unique: the same file name
                    # recurs in every query directory, and duplicate labels
                    # would otherwise raise Streamlit's DuplicateWidgetID.
                    st.download_button(label=f"Download {file_name}", data=f.read(),
                                       file_name=file_name, mime='application/json',
                                       key=file_path)

        # Visualization part
        st.title('API Benchmarking Results')

        def _meta_dict(value):
            # 'meta' holds plain dicts here (results never round-trip through
            # CSV), so read them directly. The original used eval(), which
            # raises TypeError on a dict and is unsafe on strings anyway.
            return value if isinstance(value, dict) else {}

        # Display each query and its data, grouped by query ID.
        for query_id, group in df.groupby('query_id'):
            st.header(f"Query ID: {query_id}")

            # Work on a copy so the added columns don't mutate a groupby
            # slice (avoids pandas' SettingWithCopyWarning).
            group = group.copy()

            # Extract billed input and output tokens from meta.
            group['billed_input_tokens'] = group['meta'].apply(
                lambda m: _meta_dict(m).get('billed_units', {}).get('input_tokens', 'N/A'))
            group['billed_output_tokens'] = group['meta'].apply(
                lambda m: _meta_dict(m).get('billed_units', {}).get('output_tokens', 'N/A'))

            # Display the runs for each query in a table.
            st.write(group[['message', 'response', 'billed_input_tokens',
                            'billed_output_tokens', 'model', 'tools',
                            'enable_hosted_multi_step', 'connectors', 'duration']])

            # Toggle for detailed information.
            with st.expander("Show Details"):
                for index, row in group.iterrows():
                    st.subheader(f"Run {index + 1}")
                    st.write(f"**Model:** {row['model']}")
                    st.write(f"**Tools:** {row['tools']}")
                    st.write(f"**Enable Hosted Multi-Step:** {row['enable_hosted_multi_step']}")
                    st.write(f"**Connectors:** {row['connectors']}")
                    st.write(f"**Message:** {row['message']}")
                    st.write(f"**Response:** {row['response']}")
                    st.write(f"**Response ID:** {row['response_id']}")
                    st.write(f"**Text:** {row['text']}")
                    st.write(f"**Generation ID:** {row['generation_id']}")
                    st.write(f"**Finish Reason:** {row['finish_reason']}")
                    st.write(f"**Meta:** {row['meta']}")
                    st.write(f"**Params:** {row['params']}")
                    st.write(f"**Duration:** {row['duration']}")

                    # Parse and display the meta field.
                    meta = _meta_dict(row['meta'])
                    st.write(f"API Version: {meta.get('api_version', 'N/A')}")
                    billed_units = meta.get('billed_units', {})
                    st.write(f"Billed Units - Input Tokens: {billed_units.get('input_tokens', 'N/A')}")
                    st.write(f"Billed Units - Output Tokens: {billed_units.get('output_tokens', 'N/A')}")
|