vumichien committed on
Commit
aeac4b2
·
1 Parent(s): df40595

fix raw predict

Browse files
Files changed (2) hide show
  1. models.py +6 -1
  2. routes/predict.py +45 -8
models.py CHANGED
@@ -37,9 +37,14 @@ class PredictRecord(BaseModel):
37
 
38
 
39
  class PredictResult(BaseModel):
 
 
 
 
 
 
40
  standard_subject: str
41
  standard_name: str
42
- anchor_name: str
43
  similarity_score: float
44
 
45
 
 
37
 
38
 
39
  class PredictResult(BaseModel):
40
+ subject: str
41
+ sub_subject: str
42
+ name_category: str
43
+ name: str
44
+ abstract: Optional[str] = None
45
+ memo: Optional[str] = None
46
  standard_subject: str
47
  standard_name: str
 
48
  similarity_score: float
49
 
50
 
routes/predict.py CHANGED
@@ -181,31 +181,68 @@ async def predict_raw(
181
  inputData = InputNameData(sentence_service.dic_standard_subject)
182
  # Use _add_raw_data instead of direct assignment
183
  inputData._add_raw_data(df)
184
- inputData.process_data(sentence_service.sentenceTransformerHelper)
185
  except Exception as e:
186
  print(f"Error processing input data: {e}")
187
  raise HTTPException(status_code=500, detail=str(e))
188
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  # Map standard names
190
  try:
191
  nameMapper = NameMapper(
192
  sentence_service.sentenceTransformerHelper,
193
  sentence_service.standardNameMapData,
194
- top_count=3,
195
  )
196
  df_predicted = nameMapper.predict(inputData)
197
  except Exception as e:
198
  print(f"Error mapping standard names: {e}")
 
199
  raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  # Convert results to response format
202
  results = []
203
- for _, row in df_predicted.iterrows():
204
  result = PredictResult(
205
- standard_subject=row["標準科目"],
206
- standard_name=row["標準項目名"],
207
- anchor_name=row["基準名称"],
208
- similarity_score=float(row["基準名称類似度"]),
 
 
 
 
 
209
  )
210
  results.append(result)
211
 
 
181
  inputData = InputNameData(sentence_service.dic_standard_subject)
182
  # Use _add_raw_data instead of direct assignment
183
  inputData._add_raw_data(df)
 
184
  except Exception as e:
185
  print(f"Error processing input data: {e}")
186
  raise HTTPException(status_code=500, detail=str(e))
187
+ try:
188
+ subject_mapper = SubjectMapper(
189
+ sentence_transformer_helper=sentence_service.sentenceTransformerHelper,
190
+ dic_subject_map=sentence_service.dic_standard_subject,
191
+ similarity_threshold=0.9,
192
+ )
193
+ dic_subject_map = subject_mapper.map_standard_subjects(inputData.dataframe)
194
+ except Exception as e:
195
+ print(f"Error processing SubjectMapper: {e}")
196
+ raise HTTPException(status_code=500, detail=str(e))
197
+ try:
198
+ inputData.dic_standard_subject = dic_subject_map
199
+ inputData.process_data()
200
+ except Exception as e:
201
+ print(f"Error processing inputData process_data: {e}")
202
+ raise HTTPException(status_code=500, detail=str(e))
203
  # Map standard names
204
  try:
205
  nameMapper = NameMapper(
206
  sentence_service.sentenceTransformerHelper,
207
  sentence_service.standardNameMapData,
208
+ top_count=3
209
  )
210
  df_predicted = nameMapper.predict(inputData)
211
  except Exception as e:
212
  print(f"Error mapping standard names: {e}")
213
+ traceback.print_exc()
214
  raise HTTPException(status_code=500, detail=str(e))
215
+
216
+ important_columns = ['確定', '標準科目', '標準項目名', '基準名称類似度']
217
+ for column in important_columns:
218
+ if column not in df_predicted.columns:
219
+ if column != '基準名称類似度':
220
+ df_predicted[column] = ""
221
+ inputData.dataframe[column] = ""
222
+ else:
223
+ df_predicted[column] = 0
224
+ inputData.dataframe[column] = 0
225
+
226
+ column_to_keep = ['シート名', '行', '科目', '中科目', '分類', '名称', '摘要', '備考', '確定']
227
+ output_df = inputData.dataframe[column_to_keep].copy()
228
+ output_df.reset_index(drop=False, inplace=True)
229
+ output_df.loc[:, "出力_科目"] = df_predicted["標準科目"]
230
+ output_df.loc[:, "出力_項目名"] = df_predicted["標準項目名"]
231
+ output_df.loc[:, "出力_確率度"] = df_predicted["基準名称類似度"]
232
 
233
  # Convert results to response format
234
  results = []
235
+ for _, row in output_df.iterrows():
236
  result = PredictResult(
237
+ subject=row["科目"],
238
+ sub_subject=row["中科目"],
239
+ name_category=row["分類"],
240
+ name=row["名称"],
241
+ abstract=row["摘要"],
242
+ memo=row["備考"],
243
+ standard_subject=row["出力_科目"],
244
+ standard_name=row["出力_項目名"],
245
+ similarity_score=float(row["出力_確率度"]),
246
  )
247
  results.append(result)
248