simondh commited on
Commit
8bef8d4
·
1 Parent(s): 156898c

new endpoint

Browse files
Files changed (3) hide show
  1. process.py +1 -1
  2. server.py +47 -0
  3. test_server.py +47 -1
process.py CHANGED
@@ -223,7 +223,7 @@ async def improve_classification(
223
  response = await asyncio.get_event_loop().run_in_executor(
224
  None,
225
  lambda: client.chat.completions.create(
226
- model="gpt-4",
227
  messages=[{"role": "user", "content": prompt}],
228
  temperature=0,
229
  max_tokens=300,
 
223
  response = await asyncio.get_event_loop().run_in_executor(
224
  None,
225
  lambda: client.chat.completions.create(
226
+ model="gpt-3.5-turbo",
227
  messages=[{"role": "user", "content": prompt}],
228
  temperature=0,
229
  max_tokens=300,
server.py CHANGED
@@ -11,6 +11,7 @@ import os
11
  from dotenv import load_dotenv
12
  import pandas as pd
13
  from utils import validate_results
 
14
 
15
  # Load environment variables
16
  load_dotenv()
@@ -88,6 +89,21 @@ class ValidationResponse(BaseModel):
88
  misclassifications: Optional[List[Dict[str, Any]]] = None
89
  suggested_improvements: Optional[List[str]] = None
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  @app.get("/health", response_model=HealthResponse)
92
  async def health_check() -> HealthResponse:
93
  """Check the health status of the API"""
@@ -208,6 +224,37 @@ async def validate_classifications(validation_request: ValidationRequest) -> Val
208
  except Exception as e:
209
  raise HTTPException(status_code=500, detail=str(e))
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  if __name__ == "__main__":
212
  import uvicorn
213
  uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)
 
11
  from dotenv import load_dotenv
12
  import pandas as pd
13
  from utils import validate_results
14
+ from process import improve_classification
15
 
16
  # Load environment variables
17
  load_dotenv()
 
89
  misclassifications: Optional[List[Dict[str, Any]]] = None
90
  suggested_improvements: Optional[List[str]] = None
91
 
92
+ class ImprovementRequest(BaseModel):
93
+ df: Dict[str, Any] # JSON representation of the DataFrame
94
+ validation_report: str
95
+ text_columns: List[str]
96
+ categories: str
97
+ classifier_type: str
98
+ show_explanations: bool
99
+ file_path: str
100
+
101
+ class ImprovementResponse(BaseModel):
102
+ improved_df: Dict[str, Any] # JSON representation of the improved DataFrame
103
+ new_validation_report: str
104
+ success: bool
105
+ updated_categories: List[str]
106
+
107
  @app.get("/health", response_model=HealthResponse)
108
  async def health_check() -> HealthResponse:
109
  """Check the health status of the API"""
 
224
  except Exception as e:
225
  raise HTTPException(status_code=500, detail=str(e))
226
 
227
+ @app.post("/improve-classification", response_model=ImprovementResponse)
228
+ async def improve_classification_endpoint(request: ImprovementRequest) -> ImprovementResponse:
229
+ """Improve classification based on validation report"""
230
+ try:
231
+ # Convert JSON DataFrame back to pandas DataFrame
232
+ df = pd.DataFrame.from_dict(request.df)
233
+
234
+ # Call the improve_classification function
235
+ improved_df, new_validation, success, updated_categories = await improve_classification(
236
+ df=df,
237
+ validation_report=request.validation_report,
238
+ text_columns=request.text_columns,
239
+ categories=request.categories,
240
+ classifier_type=request.classifier_type,
241
+ show_explanations=request.show_explanations,
242
+ file=request.file_path
243
+ )
244
+
245
+ # Convert improved DataFrame to JSON
246
+ improved_df_json = improved_df.to_dict() if improved_df is not None else None
247
+
248
+ return ImprovementResponse(
249
+ improved_df=improved_df_json,
250
+ new_validation_report=new_validation,
251
+ success=success,
252
+ updated_categories=updated_categories
253
+ )
254
+
255
+ except Exception as e:
256
+ raise HTTPException(status_code=500, detail=str(e))
257
+
258
  if __name__ == "__main__":
259
  import uvicorn
260
  uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)
test_server.py CHANGED
@@ -1,6 +1,7 @@
1
  import requests
2
  import json
3
  from typing import List, Dict, Any, Optional
 
4
 
5
  BASE_URL: str = "http://localhost:8000"
6
 
@@ -123,6 +124,50 @@ def test_validate_classifications() -> None:
123
  )
124
  print("\nValidation results:")
125
  print(json.dumps(response.json(), indent=2))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  if __name__ == "__main__":
128
  print("Testing FastAPI server endpoints...")
@@ -131,4 +176,5 @@ if __name__ == "__main__":
131
  test_classify_text()
132
  test_classify_batch()
133
  test_suggest_categories()
134
- test_validate_classifications()
 
 
1
  import requests
2
  import json
3
  from typing import List, Dict, Any, Optional
4
+ import pandas as pd
5
 
6
  BASE_URL: str = "http://localhost:8000"
7
 
 
124
  )
125
  print("\nValidation results:")
126
  print(json.dumps(response.json(), indent=2))
127
+ return response.json() # Return validation results for use in improve test
128
+
129
+ def test_improve_classification() -> None:
130
+ """Test the improve-classification endpoint"""
131
+ # First get validation results
132
+ validation_results = test_validate_classifications()
133
+
134
+ # Load emails from CSV file
135
+ import csv
136
+
137
+ emails: List[Dict[str, str]] = []
138
+ with open("examples/emails.csv", "r", encoding="utf-8") as file:
139
+ reader = csv.DictReader(file)
140
+ for row in reader:
141
+ emails.append(row)
142
+
143
+ # Create a DataFrame with the first 5 emails
144
+ df = pd.DataFrame(emails[:5])
145
+
146
+ # Get current categories
147
+ categories_response: requests.Response = requests.post(
148
+ f"{BASE_URL}/suggest-categories",
149
+ json=[email["contenu"] for email in emails[:5]]
150
+ )
151
+ response_data: Dict[str, Any] = categories_response.json()
152
+ current_categories: str = ",".join(response_data["categories"])
153
+
154
+ # Send improvement request
155
+ improvement_request: Dict[str, Any] = {
156
+ "df": df.to_dict(),
157
+ "validation_report": validation_results["validation_report"],
158
+ "text_columns": ["contenu"],
159
+ "categories": current_categories,
160
+ "classifier_type": "gpt35",
161
+ "show_explanations": True,
162
+ "file_path": "examples/emails.csv"
163
+ }
164
+
165
+ response: requests.Response = requests.post(
166
+ f"{BASE_URL}/improve-classification",
167
+ json=improvement_request
168
+ )
169
+ print("\nImprovement results:")
170
+ print(json.dumps(response.json(), indent=2))
171
 
172
  if __name__ == "__main__":
173
  print("Testing FastAPI server endpoints...")
 
176
  test_classify_text()
177
  test_classify_batch()
178
  test_suggest_categories()
179
+ test_validate_classifications()
180
+ test_improve_classification()