vumichien commited on
Commit
bf7f5ee
·
1 Parent(s): cfe7921

require token to use API

Browse files
Files changed (2) hide show
  1. main.py +137 -9
  2. requirements.txt +0 -0
main.py CHANGED
@@ -1,8 +1,9 @@
1
  import sys
2
  import os
3
  import time
4
- from fastapi import FastAPI, UploadFile, File, HTTPException
5
  from fastapi.responses import FileResponse
 
6
  import uvicorn
7
  import traceback
8
  import pickle
@@ -10,6 +11,12 @@ import shutil
10
  from pathlib import Path
11
  from contextlib import asynccontextmanager
12
  import pandas as pd
 
 
 
 
 
 
13
 
14
  current_dir = os.path.dirname(os.path.abspath(__file__))
15
  sys.path.append(os.path.join(current_dir, "meisai-check-ai"))
@@ -42,6 +49,105 @@ os.makedirs(os.path.join(current_dir, "data"), exist_ok=True)
42
  os.makedirs(os.path.join(current_dir, "uploads"), exist_ok=True)
43
  os.makedirs(os.path.join(current_dir, "outputs"), exist_ok=True)
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  @asynccontextmanager
47
  async def lifespan(app: FastAPI):
@@ -116,10 +222,34 @@ async def health_check():
116
  return {"status": "ok", "timestamp": time.time()}
117
 
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  @app.post("/predict")
120
- async def predict(file: UploadFile = File(...)):
 
 
 
121
  """
122
- Process an input CSV file and return standardized names
123
  """
124
  global sentenceTransformerHelper, dic_standard_subject, sample_name_sentence_embeddings
125
  global sample_name_sentence_similarities, sampleData, name_groups
@@ -129,10 +259,10 @@ async def predict(file: UploadFile = File(...)):
129
 
130
  # Save uploaded file
131
  timestamp = int(time.time())
132
- input_file_path = os.path.join(current_dir, "uploads", f"input_{timestamp}.csv")
133
 
134
  # Use CSV format with correct extension
135
- output_file_path = os.path.join(current_dir, "outputs", f"output_{timestamp}.csv")
136
 
137
  try:
138
  with open(input_file_path, "wb") as buffer:
@@ -158,9 +288,7 @@ async def predict(file: UploadFile = File(...)):
158
  # Create output dataframe and save to CSV
159
  print("Columns of inputData.dataframe", inputData.dataframe.columns)
160
  inputData.dataframe.reset_index(drop=False, inplace=True)
161
- columns_to_keep = ["ID", "シート名", "行", "科目", "名称", "摘要", "備考"]
162
- output_df = inputData.dataframe[columns_to_keep].copy()
163
- # Use .loc to avoid SettingWithCopyWarning
164
  output_df.loc[:, "出力_科目"] = df_predicted["出力_科目"]
165
  output_df.loc[:, "出力_項目名"] = df_predicted["出力_項目名"]
166
  output_df.loc[:, "出力_確率度"] = df_predicted["出力_確率度"]
@@ -184,6 +312,6 @@ async def predict(file: UploadFile = File(...)):
184
  traceback.print_exc()
185
  raise HTTPException(status_code=500, detail=str(e))
186
 
187
-
188
  if __name__ == "__main__":
189
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import sys
2
  import os
3
  import time
4
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, status
5
  from fastapi.responses import FileResponse
6
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
7
  import uvicorn
8
  import traceback
9
  import pickle
 
11
  from pathlib import Path
12
  from contextlib import asynccontextmanager
13
  import pandas as pd
14
+ from typing import Annotated
15
+ from datetime import datetime, timedelta, timezone
16
+ import jwt
17
+ from jwt.exceptions import InvalidTokenError
18
+ from passlib.context import CryptContext
19
+ from pydantic import BaseModel
20
 
21
  current_dir = os.path.dirname(os.path.abspath(__file__))
22
  sys.path.append(os.path.join(current_dir, "meisai-check-ai"))
 
49
  os.makedirs(os.path.join(current_dir, "uploads"), exist_ok=True)
50
  os.makedirs(os.path.join(current_dir, "outputs"), exist_ok=True)
51
 
52
+ # Authentication related settings
53
+ SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
54
+ ALGORITHM = "HS256"
55
+ ACCESS_TOKEN_EXPIRE_HOURS = 24 # Token expiration set to 24 hours
56
+
57
+ # Password hashing context
58
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
59
+
60
+ # OAuth2 scheme for token
61
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
62
+
63
+ # User database models
64
+ class Token(BaseModel):
65
+ access_token: str
66
+ token_type: str
67
+
68
+ class TokenData(BaseModel):
69
+ username: str | None = None
70
+
71
+ class User(BaseModel):
72
+ username: str
73
+ email: str | None = None
74
+ full_name: str | None = None
75
+ disabled: bool | None = None
76
+
77
+ class UserInDB(User):
78
+ hashed_password: str
79
+
80
+ # Fake users database with hashed passwords
81
+ users_db = {
82
+ "chien_vm": {
83
+ "username": "chien_vm",
84
+ "full_name": "Chien VM",
85
+ "email": "[email protected]",
86
+ "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
87
+ "disabled": False,
88
+ },
89
+ "hoi_nv": {
90
+ "username": "hoi_nv",
91
+ "full_name": "Hoi NV",
92
+ "email": "[email protected]",
93
+ "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
94
+ "disabled": False,
95
+ }
96
+ }
97
+
98
+ # Authentication helper functions
99
+ def verify_password(plain_password, hashed_password):
100
+ return pwd_context.verify(plain_password, hashed_password)
101
+
102
+ def get_user(db, username: str):
103
+ if username in db:
104
+ user_dict = db[username]
105
+ return UserInDB(**user_dict)
106
+ return None
107
+
108
+ def authenticate_user(fake_db, username: str, password: str):
109
+ user = get_user(fake_db, username)
110
+ if not user:
111
+ return False
112
+ if not verify_password(password, user.hashed_password):
113
+ return False
114
+ return user
115
+
116
+ def create_access_token(data: dict, expires_delta: timedelta | None = None):
117
+ to_encode = data.copy()
118
+ if expires_delta:
119
+ expire = datetime.now(timezone.utc) + expires_delta
120
+ else:
121
+ expire = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
122
+ to_encode.update({"exp": expire})
123
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
124
+ return encoded_jwt
125
+
126
+ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
127
+ credentials_exception = HTTPException(
128
+ status_code=status.HTTP_401_UNAUTHORIZED,
129
+ detail="Could not validate credentials",
130
+ headers={"WWW-Authenticate": "Bearer"},
131
+ )
132
+ try:
133
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
134
+ username = payload.get("sub")
135
+ if username is None:
136
+ raise credentials_exception
137
+ token_data = TokenData(username=username)
138
+ except InvalidTokenError:
139
+ raise credentials_exception
140
+ user = get_user(users_db, username=token_data.username)
141
+ if user is None:
142
+ raise credentials_exception
143
+ return user
144
+
145
+ async def get_current_active_user(
146
+ current_user: Annotated[User, Depends(get_current_user)],
147
+ ):
148
+ if current_user.disabled:
149
+ raise HTTPException(status_code=400, detail="Inactive user")
150
+ return current_user
151
 
152
  @asynccontextmanager
153
  async def lifespan(app: FastAPI):
 
222
  return {"status": "ok", "timestamp": time.time()}
223
 
224
 
225
+ @app.post("/token")
226
+ async def login_for_access_token(
227
+ form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
228
+ ) -> Token:
229
+ """
230
+ Login endpoint to get an access token
231
+ """
232
+ user = authenticate_user(users_db, form_data.username, form_data.password)
233
+ if not user:
234
+ raise HTTPException(
235
+ status_code=status.HTTP_401_UNAUTHORIZED,
236
+ detail="Incorrect username or password",
237
+ headers={"WWW-Authenticate": "Bearer"},
238
+ )
239
+ access_token_expires = timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
240
+ access_token = create_access_token(
241
+ data={"sub": user.username}, expires_delta=access_token_expires
242
+ )
243
+ return Token(access_token=access_token, token_type="bearer")
244
+
245
+
246
  @app.post("/predict")
247
+ async def predict(
248
+ current_user: Annotated[User, Depends(get_current_active_user)],
249
+ file: UploadFile = File(...)
250
+ ):
251
  """
252
+ Process an input CSV file and return standardized names (requires authentication)
253
  """
254
  global sentenceTransformerHelper, dic_standard_subject, sample_name_sentence_embeddings
255
  global sample_name_sentence_similarities, sampleData, name_groups
 
259
 
260
  # Save uploaded file
261
  timestamp = int(time.time())
262
+ input_file_path = os.path.join(current_dir, "uploads", f"input_{timestamp}_{current_user.username}.csv")
263
 
264
  # Use CSV format with correct extension
265
+ output_file_path = os.path.join(current_dir, "outputs", f"output_{timestamp}_{current_user.username}.csv")
266
 
267
  try:
268
  with open(input_file_path, "wb") as buffer:
 
288
  # Create output dataframe and save to CSV
289
  print("Columns of inputData.dataframe", inputData.dataframe.columns)
290
  inputData.dataframe.reset_index(drop=False, inplace=True)
291
+ output_df = inputData.dataframe.copy()
 
 
292
  output_df.loc[:, "出力_科目"] = df_predicted["出力_科目"]
293
  output_df.loc[:, "出力_項目名"] = df_predicted["出力_項目名"]
294
  output_df.loc[:, "出力_確率度"] = df_predicted["出力_確率度"]
 
312
  traceback.print_exc()
313
  raise HTTPException(status_code=500, detail=str(e))
314
 
315
+
316
  if __name__ == "__main__":
317
  uvicorn.run(app, host="0.0.0.0", port=8000)
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ