vumichien committed
Commit b77c0a2 · 1 Parent(s): dc7fc97

update project structure

auth.py ADDED
@@ -0,0 +1,67 @@
+ from datetime import datetime, timedelta, timezone
+ import jwt
+ from fastapi import Depends, HTTPException, status
+ from fastapi.security import OAuth2PasswordBearer
+ from passlib.context import CryptContext
+ from config import SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_HOURS
+ from models import TokenData, UserInDB, User
+ from database import users_db
+ from typing import Annotated, Optional
+ from jwt.exceptions import InvalidTokenError
+
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+ # Authentication helper functions
+ def verify_password(plain_password, hashed_password):
+     return pwd_context.verify(plain_password, hashed_password)
+
+ def get_user(db, username: str):
+     if username in db:
+         user_dict = db[username]
+         return UserInDB(**user_dict)
+     return None
+
+ def authenticate_user(fake_db, username: str, password: str):
+     user = get_user(fake_db, username)
+     if not user:
+         return False
+     if not verify_password(password, user.hashed_password):
+         return False
+     return user
+
+ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
+     to_encode = data.copy()
+     if expires_delta:
+         expire = datetime.now(timezone.utc) + expires_delta
+     else:
+         expire = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
+     to_encode.update({"exp": expire})
+     encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
+     return encoded_jwt
+
+ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
+     credentials_exception = HTTPException(
+         status_code=status.HTTP_401_UNAUTHORIZED,
+         detail="Could not validate credentials",
+         headers={"WWW-Authenticate": "Bearer"},
+     )
+     try:
+         payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+         username = payload.get("sub")
+         if username is None:
+             raise credentials_exception
+         token_data = TokenData(username=username)
+     except InvalidTokenError:
+         raise credentials_exception
+     user = get_user(users_db, username=token_data.username)
+     if user is None:
+         raise credentials_exception
+     return user
+
+ async def get_current_active_user(
+     current_user: Annotated[User, Depends(get_current_user)],
+ ):
+     if current_user.disabled:
+         raise HTTPException(status_code=400, detail="Inactive user")
+     return current_user
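A quick sanity-check sketch of how these helpers fit together (assuming this repository's auth.py and config.py are importable; the username is one of the entries in database.py):

# Issue a token and decode it back (PyJWT round trip)
from datetime import timedelta
import jwt
from auth import create_access_token
from config import SECRET_KEY, ALGORITHM

token = create_access_token(data={"sub": "chien_vm"}, expires_delta=timedelta(hours=1))
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
assert payload["sub"] == "chien_vm"  # the subject claim survives the round trip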
config.py ADDED
@@ -0,0 +1,19 @@
+ import os
+
+ # Security Config
+ SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
+ ALGORITHM = "HS256"
+ ACCESS_TOKEN_EXPIRE_HOURS = 24
+
+ # Paths
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+ DATA_DIR = os.path.join(BASE_DIR, "data")
+ UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
+ OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
+ SUBJECT_DATA_FILE = os.path.join(DATA_DIR, "subjectData.csv")
+ SAMPLE_DATA_FILE = os.path.join(DATA_DIR, "sampleData.csv")
+ # Model and embedding files
+ MODEL_NAME = "Detomo/cl-nagoya-sup-simcse-ja-for-standard-name-v1_0"
+ SENTENCE_EMBEDDING_FILE = os.path.join(DATA_DIR, "sample_name_sentence_embeddings(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl")
+ SENTENCE_SIMILARITY_FILE = os.path.join(DATA_DIR, "sample_name_sentence_similarities(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl")
+
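Note that SECRET_KEY is committed in plain text. A minimal sketch of the usual alternative (reading it from an environment variable; the variable name and fallback value are assumptions, not part of this commit):

import os

# Prefer an injected secret; fall back to a development-only default
SECRET_KEY = os.environ.get("SECRET_KEY", "dev-only-insecure-key")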
database.py ADDED
@@ -0,0 +1,16 @@
+ users_db = {
+     "chien_vm": {
+         "username": "chien_vm",
+         "full_name": "Chien VM",
+         "email": "[email protected]",
+         "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
+         "disabled": False,
+     },
+     "hoi_nv": {
+         "username": "hoi_nv",
+         "full_name": "Hoi NV",
+         "email": "[email protected]",
+         "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
+         "disabled": False,
+     }
+ }
main.py CHANGED
@@ -1,318 +1,55 @@
import sys
import os
- import time
- from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, status
- from fastapi.responses import FileResponse
- from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
+ from fastapi import FastAPI
import uvicorn
import traceback
- import pickle
- import shutil
- from pathlib import Path
from contextlib import asynccontextmanager
- import pandas as pd
- from typing import Annotated, Optional, Union
- from datetime import datetime, timedelta, timezone
- import jwt
- from jwt.exceptions import InvalidTokenError
- from passlib.context import CryptContext
- from pydantic import BaseModel

current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(current_dir, "meisai-check-ai"))

- from sentence_transformer_lib.sentence_transformer_helper import (
-     SentenceTransformerHelper,
- )
- from data_lib.input_name_data import InputNameData
- from data_lib.subject_data import SubjectData
- from data_lib.sample_name_data import SampleNameData
- from clustering_lib.sentence_clustering_lib import SentenceClusteringLib
- from data_lib.base_data import (
-     COL_STANDARD_NAME,
-     COL_STANDARD_NAME_KEY,
-     COL_STANDARD_SUBJECT,
- )
- from mapping_lib.name_mapping_helper import NameMappingHelper
-
- # Initialize global variables for model and data
- sentenceTransformerHelper = None
- dic_standard_subject = None
- sample_name_sentence_embeddings = None
- sample_name_sentence_similarities = None
- sampleData = None
- sentence_clustering_lib = None
- name_groups = None
-
- # Create data directory if it doesn't exist
- os.makedirs(os.path.join(current_dir, "data"), exist_ok=True)
- os.makedirs(os.path.join(current_dir, "uploads"), exist_ok=True)
- os.makedirs(os.path.join(current_dir, "outputs"), exist_ok=True)
-
- # Authentication related settings
- SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
- ALGORITHM = "HS256"
- ACCESS_TOKEN_EXPIRE_HOURS = 24  # Token expiration set to 24 hours
-
- # Password hashing context
- pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
-
- # OAuth2 scheme for token
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-
- # User database models
- class Token(BaseModel):
-     access_token: str
-     token_type: str
-
- class TokenData(BaseModel):
-     username: Optional[str] = None
-
- class User(BaseModel):
-     username: str
-     email: Optional[str] = None
-     full_name: Optional[str] = None
-     disabled: Optional[bool] = None
-
- class UserInDB(User):
-     hashed_password: str
-
- # Fake users database with hashed passwords
- users_db = {
-     "chien_vm": {
-         "username": "chien_vm",
-         "full_name": "Chien VM",
-         "email": "[email protected]",
-         "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
-         "disabled": False,
-     },
-     "hoi_nv": {
-         "username": "hoi_nv",
-         "full_name": "Hoi NV",
-         "email": "[email protected]",
-         "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
-         "disabled": False,
-     }
- }
+ from routes import auth, predict, health
+ from services.sentence_transformer_service import sentence_transformer_service
+ from utils import create_directories

- # Authentication helper functions
- def verify_password(plain_password, hashed_password):
-     return pwd_context.verify(plain_password, hashed_password)
-
- def get_user(db, username: str):
-     if username in db:
-         user_dict = db[username]
-         return UserInDB(**user_dict)
-     return None
-
- def authenticate_user(fake_db, username: str, password: str):
-     user = get_user(fake_db, username)
-     if not user:
-         return False
-     if not verify_password(password, user.hashed_password):
-         return False
-     return user
-
- def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
-     to_encode = data.copy()
-     if expires_delta:
-         expire = datetime.now(timezone.utc) + expires_delta
-     else:
-         expire = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
-     to_encode.update({"exp": expire})
-     encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
-     return encoded_jwt
-
- async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
-     credentials_exception = HTTPException(
-         status_code=status.HTTP_401_UNAUTHORIZED,
-         detail="Could not validate credentials",
-         headers={"WWW-Authenticate": "Bearer"},
-     )
-     try:
-         payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
-         username = payload.get("sub")
-         if username is None:
-             raise credentials_exception
-         token_data = TokenData(username=username)
-     except InvalidTokenError:
-         raise credentials_exception
-     user = get_user(users_db, username=token_data.username)
-     if user is None:
-         raise credentials_exception
-     return user
-
- async def get_current_active_user(
-     current_user: Annotated[User, Depends(get_current_user)],
- ):
-     if current_user.disabled:
-         raise HTTPException(status_code=400, detail="Inactive user")
-     return current_user

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup and shutdown events"""
-     global sentenceTransformerHelper, dic_standard_subject, sample_name_sentence_embeddings
-     global sample_name_sentence_similarities, sampleData, sentence_clustering_lib, name_groups
-
    try:
-         # Load sentence transformer model
-         sentenceTransformerHelper = SentenceTransformerHelper(
-             convert_to_zenkaku_flag=True, replace_words=None, keywords=None
-         )
-         sentenceTransformerHelper.load_model_by_name(
-             "Detomo/cl-nagoya-sup-simcse-ja-for-standard-name-v1_0"
-         )
-
-         # Load standard subject dictionary
-         dic_standard_subject = SubjectData.create_standard_subject_dic_from_file(
-             "data/subjectData.csv"
-         )
-
-         # Load pre-computed embeddings and similarities
-         with open(
-             f"data/sample_name_sentence_embeddings(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl",
-             "rb",
-         ) as f:
-             sample_name_sentence_embeddings = pickle.load(f)
-
-         with open(
-             f"data/sample_name_sentence_similarities(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl",
-             "rb",
-         ) as f:
-             sample_name_sentence_similarities = pickle.load(f)
-
-         # Load and process sample data
-         sampleData = SampleNameData()
-         file_path = os.path.join(current_dir, "data", "sampleData.csv")
-         sampleData.load_data_from_csv(file_path)
-         sampleData.process_data()
-
-         # Create sentence clusters
-         sentence_clustering_lib = SentenceClusteringLib(sample_name_sentence_embeddings)
-         best_name_eps = 0.07
-         name_groups, _ = sentence_clustering_lib.create_sentence_cluster(best_name_eps)
-         sampleData._create_key_column(
-             COL_STANDARD_NAME_KEY, COL_STANDARD_SUBJECT, COL_STANDARD_NAME
-         )
-         sampleData.set_name_sentence_labels(name_groups)
-         sampleData.build_search_tree()
-
-         print("Models and data loaded successfully")
+         # Load models and data ONCE at startup
+         sentence_transformer_service.load_model_data()
    except Exception as e:
        print(f"Error during startup: {e}")
        traceback.print_exc()

-     yield  # This is where the app runs
+     yield  # The app runs here

-     # Cleanup code (if needed) goes here
    print("Shutting down application")

-
- app = FastAPI(lifespan=lifespan)
+ # Initialize FastAPI
+ app = FastAPI(
+     title="MeisaiCheck API",
+     description="API for MeisaiCheck AI System",
+     version="1.0",
+     lifespan=lifespan,
+     openapi_tags=[
+         {"name": "Health", "description": "Health check endpoints"},
+         {"name": "Authentication", "description": "User authentication and token management"},
+         {"name": "Prediction", "description": "Predict and process CSV files"},
+     ]
+ )
+
+ # Include Routers
+ app.include_router(health.router, tags=["Health"])
+ app.include_router(auth.router, tags=["Authentication"])
+ app.include_router(predict.router, tags=["Prediction"])


- @app.get("/")
+ @app.get("/", tags=["Health"])
async def root():
    return {"message": "Hello World"}

-
- @app.get("/health")
- async def health_check():
-     return {"status": "ok", "timestamp": time.time()}
-
-
- @app.post("/token")
- async def login_for_access_token(
-     form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
- ) -> Token:
-     """
-     Login endpoint to get an access token
-     """
-     user = authenticate_user(users_db, form_data.username, form_data.password)
-     if not user:
-         raise HTTPException(
-             status_code=status.HTTP_401_UNAUTHORIZED,
-             detail="Incorrect username or password",
-             headers={"WWW-Authenticate": "Bearer"},
-         )
-     access_token_expires = timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
-     access_token = create_access_token(
-         data={"sub": user.username}, expires_delta=access_token_expires
-     )
-     return Token(access_token=access_token, token_type="bearer")
-
-
- @app.post("/predict")
- async def predict(
-     current_user: Annotated[User, Depends(get_current_active_user)],
-     file: UploadFile = File(...)
- ):
-     """
-     Process an input CSV file and return standardized names (requires authentication)
-     """
-     global sentenceTransformerHelper, dic_standard_subject, sample_name_sentence_embeddings
-     global sample_name_sentence_similarities, sampleData, name_groups
-
-     if not file.filename.endswith(".csv"):
-         raise HTTPException(status_code=400, detail="Only CSV files are supported")
-
-     # Save uploaded file
-     timestamp = int(time.time())
-     input_file_path = os.path.join(current_dir, "uploads", f"input_{timestamp}_{current_user.username}.csv")
-
-     # Use CSV format with correct extension
-     output_file_path = os.path.join(current_dir, "outputs", f"output_{timestamp}_{current_user.username}.csv")
-
-     try:
-         with open(input_file_path, "wb") as buffer:
-             shutil.copyfileobj(file.file, buffer)
-     finally:
-         file.file.close()
-
-     try:
-         # Process input data
-         inputData = InputNameData(dic_standard_subject)
-         inputData.load_data_from_csv(input_file_path)
-         inputData.process_data()
-
-         # Map standard names
-         nameMappingHelper = NameMappingHelper(
-             sentenceTransformerHelper,
-             inputData,
-             sampleData,
-             sample_name_sentence_embeddings,
-             sample_name_sentence_similarities,
-         )
-         df_predicted = nameMappingHelper.map_standard_names()
-         # Create output dataframe and save to CSV
-         print("Columns of inputData.dataframe", inputData.dataframe.columns)
-         column_to_keep = ['シート名', '行', '科目', '分類', '名称', '摘要', '備考']
-         output_df = inputData.dataframe[column_to_keep].copy()
-         output_df.reset_index(drop=False, inplace=True)
-         output_df.loc[:, "出力_科目"] = df_predicted["出力_科目"]
-         output_df.loc[:, "出力_項目名"] = df_predicted["出力_項目名"]
-         output_df.loc[:, "出力_確率度"] = df_predicted["出力_確率度"]
-
-         # Save with utf_8_sig encoding for Japanese Excel compatibility
-         output_df.to_csv(output_file_path, index=False, encoding="utf_8_sig")
-
-         # Return the file as a download with correct content type and headers
-         return FileResponse(
-             path=output_file_path,
-             filename=f"output_{Path(file.filename).stem}.csv",
-             media_type="text/csv",
-             headers={
-                 "Content-Disposition": f'attachment; filename="output_{Path(file.filename).stem}.csv"',
-                 "Content-Type": "application/x-www-form-urlencoded",
-             },
-         )
-
-     except Exception as e:
-         print(f"Error processing file: {e}")
-         traceback.print_exc()
-         raise HTTPException(status_code=500, detail=str(e))
-
-
if __name__ == "__main__":
+     create_directories()
    uvicorn.run(app, host="0.0.0.0", port=8000)
models.py ADDED
@@ -0,0 +1,18 @@
+ from pydantic import BaseModel
+ from typing import Optional
+
+ class Token(BaseModel):
+     access_token: str
+     token_type: str
+
+ class TokenData(BaseModel):
+     username: Optional[str] = None
+
+ class User(BaseModel):
+     username: str
+     email: Optional[str] = None
+     full_name: Optional[str] = None
+     disabled: Optional[bool] = None
+
+ class UserInDB(User):
+     hashed_password: str
routes/__init__.py ADDED
File without changes
routes/auth.py ADDED
@@ -0,0 +1,28 @@
+ from fastapi import APIRouter, Depends, HTTPException, status
+ from fastapi.security import OAuth2PasswordRequestForm
+ from datetime import timedelta
+ from auth import authenticate_user, create_access_token
+ from models import Token
+ from config import ACCESS_TOKEN_EXPIRE_HOURS
+ from database import users_db
+
+ router = APIRouter()
+
+ @router.post("/token", response_model=Token)
+ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
+     """
+     Endpoint to obtain an access token with a username and password
+     """
+     user = authenticate_user(users_db, form_data.username, form_data.password)
+     if not user:
+         raise HTTPException(
+             status_code=status.HTTP_401_UNAUTHORIZED,
+             detail="Incorrect username or password",
+             headers={"WWW-Authenticate": "Bearer"},
+         )
+
+     access_token_expires = timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
+     access_token = create_access_token(
+         data={"sub": user.username}, expires_delta=access_token_expires
+     )
+     return Token(access_token=access_token, token_type="bearer")
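A minimal client-side sketch of this login flow (assuming a server on localhost:8000 and the requests library; the password value is a placeholder):

import requests

# OAuth2PasswordRequestForm expects form-encoded fields, not JSON
resp = requests.post(
    "http://localhost:8000/token",
    data={"username": "chien_vm", "password": "<password>"},
)
resp.raise_for_status()
token = resp.json()["access_token"]  # bearer token for subsequent requests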
routes/health.py ADDED
@@ -0,0 +1,8 @@
+ from fastapi import APIRouter
+ import time
+
+ router = APIRouter()
+
+ @router.get("/health")
+ async def health_check():
+     return {"status": "ok", "timestamp": time.time()}
routes/predict.py ADDED
@@ -0,0 +1,77 @@
+ import os
+ import time
+ import shutil
+ from pathlib import Path
+ from fastapi import APIRouter, UploadFile, File, HTTPException, Depends
+ from fastapi.responses import FileResponse
+ from auth import get_current_user
+ from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
+ from data_lib.input_name_data import InputNameData
+ from mapping_lib.name_mapping_helper import NameMappingHelper
+ from config import UPLOAD_DIR, OUTPUT_DIR
+
+ router = APIRouter()
+
+ @router.post("/predict")
+ async def predict(
+     current_user=Depends(get_current_user),
+     file: UploadFile = File(...),
+     sentence_service: SentenceTransformerService = Depends(lambda: sentence_transformer_service)
+ ):
+     """
+     Process an input CSV file and return standardized names (requires authentication)
+     """
+     if not file.filename.endswith(".csv"):
+         raise HTTPException(status_code=400, detail="Only CSV files are supported")
+
+     # Save uploaded file
+     timestamp = int(time.time())
+     input_file_path = os.path.join(UPLOAD_DIR, f"input_{timestamp}_{current_user.username}.csv")
+     output_file_path = os.path.join(OUTPUT_DIR, f"output_{timestamp}_{current_user.username}.csv")
+
+     try:
+         with open(input_file_path, "wb") as buffer:
+             shutil.copyfileobj(file.file, buffer)
+     finally:
+         file.file.close()
+
+     try:
+         # Process input data
+         inputData = InputNameData(sentence_service.dic_standard_subject)
+         inputData.load_data_from_csv(input_file_path)
+         inputData.process_data()
+
+         # Map standard names
+         nameMappingHelper = NameMappingHelper(
+             sentence_service.sentenceTransformerHelper,
+             inputData,
+             sentence_service.sampleData,
+             sentence_service.sample_name_sentence_embeddings,
+             sentence_service.sample_name_sentence_similarities,
+         )
+         df_predicted = nameMappingHelper.map_standard_names()
+
+         # Create output dataframe and save to CSV
+         column_to_keep = ['シート名', '行', '科目', '分類', '名称', '摘要', '備考']
+         output_df = inputData.dataframe[column_to_keep].copy()
+         output_df.reset_index(drop=False, inplace=True)
+         output_df.loc[:, "出力_科目"] = df_predicted["出力_科目"]
+         output_df.loc[:, "出力_項目名"] = df_predicted["出力_項目名"]
+         output_df.loc[:, "出力_確率度"] = df_predicted["出力_確率度"]
+
+         # Save with utf_8_sig encoding for Japanese Excel compatibility
+         output_df.to_csv(output_file_path, index=False, encoding="utf_8_sig")
+
+         return FileResponse(
+             path=output_file_path,
+             filename=f"output_{Path(file.filename).stem}.csv",
+             media_type="text/csv",
+             headers={
+                 "Content-Disposition": f'attachment; filename="output_{Path(file.filename).stem}.csv"',
+                 "Content-Type": "text/csv",
+             },
+         )
+
+     except Exception as e:
+         print(f"Error processing file: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
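And a matching sketch for calling /predict with a bearer token (same assumptions as the /token example; input.csv is a placeholder file):

import requests

token = "<access token from the /token example above>"
with open("input.csv", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/predict",
        headers={"Authorization": f"Bearer {token}"},
        files={"file": ("input.csv", f, "text/csv")},  # multipart field must be named "file"
    )
resp.raise_for_status()
with open("output.csv", "wb") as out:
    out.write(resp.content)  # the standardized CSV returned as a download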
services/__init__.py ADDED
File without changes
services/sentence_transformer_service.py ADDED
@@ -0,0 +1,66 @@
+ import pickle
+ from config import (
+     MODEL_NAME,
+     SENTENCE_EMBEDDING_FILE,
+     SENTENCE_SIMILARITY_FILE,
+     SAMPLE_DATA_FILE, SUBJECT_DATA_FILE
+ )
+ from sentence_transformer_lib.sentence_transformer_helper import SentenceTransformerHelper
+ from data_lib.subject_data import SubjectData
+ from data_lib.sample_name_data import SampleNameData
+ from clustering_lib.sentence_clustering_lib import SentenceClusteringLib
+ from data_lib.base_data import COL_STANDARD_NAME_KEY, COL_STANDARD_SUBJECT, COL_STANDARD_NAME
+
+ class SentenceTransformerService:
+     def __init__(self):
+         self.sentenceTransformerHelper = None
+         self.dic_standard_subject = None
+         self.sample_name_sentence_embeddings = None
+         self.sample_name_sentence_similarities = None
+         self.sampleData = None
+         self.sentence_clustering_lib = None
+         self.name_groups = None
+
+     def load_model_data(self):
+         """Load model and data only once at startup"""
+         if self.sentenceTransformerHelper is not None:
+             print("Model already loaded. Skipping reload.")
+             return  # Do not reload if the model is already loaded
+
+         print("Loading models and data...")
+         # Load sentence transformer model
+         self.sentenceTransformerHelper = SentenceTransformerHelper(
+             convert_to_zenkaku_flag=True, replace_words=None, keywords=None
+         )
+         self.sentenceTransformerHelper.load_model_by_name(MODEL_NAME)
+
+         # Load standard subject dictionary
+         self.dic_standard_subject = SubjectData.create_standard_subject_dic_from_file(SUBJECT_DATA_FILE)
+
+         # Load pre-computed embeddings and similarities
+         with open(SENTENCE_EMBEDDING_FILE, "rb") as f:
+             self.sample_name_sentence_embeddings = pickle.load(f)
+
+         with open(SENTENCE_SIMILARITY_FILE, "rb") as f:
+             self.sample_name_sentence_similarities = pickle.load(f)
+
+         # Load and process sample data
+         self.sampleData = SampleNameData()
+         self.sampleData.load_data_from_csv(SAMPLE_DATA_FILE)
+         self.sampleData.process_data()
+
+         # Create sentence clusters
+         self.sentence_clustering_lib = SentenceClusteringLib(self.sample_name_sentence_embeddings)
+         best_name_eps = 0.07
+         self.name_groups, _ = self.sentence_clustering_lib.create_sentence_cluster(best_name_eps)
+
+         self.sampleData._create_key_column(
+             COL_STANDARD_NAME_KEY, COL_STANDARD_SUBJECT, COL_STANDARD_NAME
+         )
+         self.sampleData.set_name_sentence_labels(self.name_groups)
+         self.sampleData.build_search_tree()
+
+         print("Models and data loaded successfully")
+
+ # Global instance (singleton)
+ sentence_transformer_service = SentenceTransformerService()
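The module-level instance makes repeat loads cheap; a short sketch of the intended call pattern (assuming the package layout above):

from services.sentence_transformer_service import sentence_transformer_service

sentence_transformer_service.load_model_data()  # loads model, embeddings, and sample data
sentence_transformer_service.load_model_data()  # no-op: prints "Model already loaded. Skipping reload."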
utils.py ADDED
@@ -0,0 +1,7 @@
+ import os
+ from config import DATA_DIR, UPLOAD_DIR, OUTPUT_DIR
+
+ def create_directories():
+     os.makedirs(DATA_DIR, exist_ok=True)
+     os.makedirs(UPLOAD_DIR, exist_ok=True)
+     os.makedirs(OUTPUT_DIR, exist_ok=True)