Hamed Mohammadpour committed on
Commit
d2c6071
·
1 Parent(s): 62857df

Add application file

Browse files
Files changed (1) hide show
  1. app.py +171 -0
app.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Standard library
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor

# Third-party
import pandas as pd
import requests
import streamlit as st

# Streamlit app configuration: use the full browser width.
st.set_page_config(layout="wide")
11
+
# Function to make the API call and save the response
def make_api_call(query, params, url, headers, output_dir):
    """POST one chat request to the Cohere API and record the result.

    Parameters
    ----------
    query : dict
        Must contain "id" (used for the output folder name) and "message"
        (the chat message to send).
    params : dict
        Extra request parameters merged into the payload; must contain
        "model", and may contain "tools", "connectors",
        "enable_hosted_multi_step".
    url : str
        Chat endpoint URL.
    headers : dict
        HTTP headers (including the Authorization bearer token).
    output_dir : str
        Root directory; the raw JSON response is saved under
        output_dir/query_<id>/.

    Returns
    -------
    dict
        Flat record (model, params, extracted response fields, duration)
        suitable as one row of a results DataFrame.
    """
    data = {
        "message": query["message"],
        "stream": False,
        "return_prompt": True,
        **params,
    }

    # Time only the HTTP round trip.
    start_time = time.time()
    response = requests.post(url, headers=headers, json=data)
    duration = time.time() - start_time

    # Parse the body exactly once and reuse it (the original called
    # response.json() twice, re-parsing the same payload).
    response_data = response.json()

    query_id = query["id"]
    query_dir = os.path.join(output_dir, f'query_{query_id}')
    os.makedirs(query_dir, exist_ok=True)

    # "tools" present => hosted multi-step run; otherwise single step.
    file_name = f'response_{params["model"]}_{"multi_step" if params.get("tools") else "single_step"}.json'
    file_path = os.path.join(query_dir, file_name)
    with open(file_path, 'w') as f:
        json.dump(response_data, f, indent=2)

    # Hoist the repeated meta lookup instead of chaining .get() each time.
    meta = response_data.get("meta", {})
    extracted_data = {
        "response_id": response_data.get("response_id"),
        "text": response_data.get("text"),
        "generation_id": response_data.get("generation_id"),
        "finish_reason": response_data.get("finish_reason"),
        "meta": {
            "api_version": meta.get("api_version", {}).get("version"),
            "billed_units": {
                "input_tokens": meta.get("billed_units", {}).get("input_tokens"),
                "output_tokens": meta.get("billed_units", {}).get("output_tokens"),
            },
            "tokens": {
                "input_tokens": meta.get("tokens", {}).get("input_tokens"),
                "output_tokens": meta.get("tokens", {}).get("output_tokens"),
            },
        },
    }

    return {
        "query_id": query_id,
        "model": params["model"],
        # Falls back to connectors when no tools were given, so the
        # results table always shows what drove the search.
        "tools": params.get("tools", params.get("connectors")),
        "enable_hosted_multi_step": params.get("enable_hosted_multi_step"),
        "connectors": params.get("connectors"),
        "message": query["message"],
        "response": extracted_data["text"],
        **extracted_data,
        "params": params,
        "duration": duration,
    }
# Streamlit UI
st.title('Cohere API Benchmarking Tool')

# Input for API key
api_key = st.text_input('Enter your Cohere API key:', type='password')

# Input for messages: one benchmark query per line; blank lines are
# skipped but still count toward the 1-based query id.
messages_input = st.text_area('Enter messages to benchmark (one per line):')
messages = [
    {"id": line_no, "message": line}
    for line_no, line in enumerate(messages_input.split('\n'), start=1)
    if line.strip()
]

# The four benchmark configurations: each model is exercised once with
# hosted multi-step tools and once with the web-search connector.
param_combinations = [
    {
        "model": "command-r",
        "tools": [{"name": "internet_search"}],
        "enable_hosted_multi_step": True,
    },
    {
        "model": "command-r",
        "connectors": [{"id": "web-search", "name": "web search"}],
    },
    {
        "model": "command-r-plus",
        "tools": [{"name": "internet_search"}],
        "enable_hosted_multi_step": True,
    },
    {
        "model": "command-r-plus",
        "connectors": [{"id": "web-search", "name": "web search"}],
    },
]

# API endpoint and request headers.
url = 'https://production.api.cohere.ai/v1/chat'
headers = {
    'Content-Type': 'application/json',
    'Authorization': f'Bearer {api_key}',
}

# Directory where raw JSON responses are stored, one subfolder per query.
output_dir = 'api_responses'
os.makedirs(output_dir, exist_ok=True)
if st.button('Run Benchmark'):
    if not api_key:
        st.error('API key is required.')
    elif not messages:
        st.error('At least one message is required.')
    else:
        # Fan out one API call per (message, parameter-set) pair in
        # parallel threads; requests release the GIL during I/O.
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(make_api_call, query, params, url, headers, output_dir)
                for query in messages
                for params in param_combinations
            ]
            # Collect in submission order so rows line up with queries.
            results = [future.result() for future in futures]

        # Persist the full run as CSV for offline analysis.
        df = pd.DataFrame(results)
        df.to_csv('api_benchmarking_results.csv', index=False)

        st.success('Benchmarking completed!')

        # Display the DataFrame
        st.dataframe(df)

        # Offer download of the saved raw JSON responses, per query.
        for query in messages:
            query_id = query["id"]
            query_dir = os.path.join(output_dir, f'query_{query_id}')
            st.markdown(f"### Query ID: {query_id}")
            for file_name in os.listdir(query_dir):
                file_path = os.path.join(query_dir, file_name)
                with open(file_path, 'r') as f:
                    st.download_button(label=f"Download {file_name}", data=f.read(), file_name=file_name, mime='application/json')

        # Visualization part
        st.title('API Benchmarking Results')

        # Display each query and its runs.
        for query_id, group in df.groupby('query_id'):
            st.header(f"Query ID: {query_id}")

            # BUGFIX: df comes straight from make_api_call, so 'meta' is
            # already a dict — the original ran eval() on it, which raises
            # TypeError on a dict and is unsafe on untrusted strings.
            # Read the dict directly instead.
            # Work on a copy so we don't assign into a groupby slice
            # (avoids pandas SettingWithCopyWarning).
            group = group.copy()
            group['billed_input_tokens'] = group['meta'].apply(
                lambda m: m.get('billed_units', {}).get('input_tokens', 'N/A'))
            group['billed_output_tokens'] = group['meta'].apply(
                lambda m: m.get('billed_units', {}).get('output_tokens', 'N/A'))

            # Display the runs for each query in a table
            st.write(group[['message', 'response', 'billed_input_tokens', 'billed_output_tokens', 'model', 'tools', 'enable_hosted_multi_step', 'connectors', 'duration']])

            # Toggle for detailed per-run information
            with st.expander("Show Details"):
                for index, row in group.iterrows():
                    st.subheader(f"Run {index + 1}")
                    st.write(f"**Model:** {row['model']}")
                    st.write(f"**Tools:** {row['tools']}")
                    st.write(f"**Enable Hosted Multi-Step:** {row['enable_hosted_multi_step']}")
                    st.write(f"**Connectors:** {row['connectors']}")
                    st.write(f"**Message:** {row['message']}")
                    st.write(f"**Response:** {row['response']}")
                    st.write(f"**Response ID:** {row['response_id']}")
                    st.write(f"**Text:** {row['text']}")
                    st.write(f"**Generation ID:** {row['generation_id']}")
                    st.write(f"**Finish Reason:** {row['finish_reason']}")
                    st.write(f"**Meta:** {row['meta']}")
                    st.write(f"**Params:** {row['params']}")
                    st.write(f"**Duration:** {row['duration']}")

                    # Same fix as above: 'meta' is already a dict, no eval.
                    meta = row['meta']
                    st.write(f"API Version: {meta.get('api_version', 'N/A')}")
                    billed_units = meta.get('billed_units', {})
                    st.write(f"Billed Units - Input Tokens: {billed_units.get('input_tokens', 'N/A')}")
                    st.write(f"Billed Units - Output Tokens: {billed_units.get('output_tokens', 'N/A')}")