change logic from sentence name to representative name
- auth.py +2 -2
- config.py +4 -3
- data/sample_name_sentence_similarities(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl +0 -3
- data/{sample_name_sentence_similarities(cl-nagoya-sup-simcse-ja-base).pkl → sample_representative_name_sentence_embeddings(cl-nagoya-sup-simcse-ja-for-standard-name-v1_0).pkl} +2 -2
- data/users.json +16 -0
- database.py +71 -14
- main.py +1 -1
- models.py +5 -0
- requirements.txt +0 -0
- routes/auth.py +14 -3
- routes/predict.py +11 -1
- services/sentence_transformer_service.py +3 -21
auth.py CHANGED

@@ -5,7 +5,7 @@ from fastapi.security import OAuth2PasswordBearer
 from passlib.context import CryptContext
 from config import SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_HOURS
 from models import TokenData, UserInDB, User
-from database import
+from database import get_users
 from typing import Annotated, Optional
 from jwt.exceptions import InvalidTokenError

@@ -54,7 +54,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
         token_data = TokenData(username=username)
     except InvalidTokenError:
         raise credentials_exception
-    user = get_user(
+    user = get_user(get_users(), username=token_data.username)
     if user is None:
         raise credentials_exception
     return user
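get_user itself is outside this diff; a minimal sketch consistent with the new call site, modeled on the FastAPI OAuth2 tutorial helper (the exact body is an assumption):

from models import UserInDB

def get_user(db: dict, username: str):
    # db is the plain dict returned by get_users()
    if username in db:
        return UserInDB(**db[username])
    return None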
config.py CHANGED

@@ -14,6 +14,7 @@ SUBJECT_DATA_FILE = os.path.join(DATA_DIR, "subjectData.csv")
 SAMPLE_DATA_FILE = os.path.join(DATA_DIR, "sampleData.csv")
 # Model Names
 MODEL_NAME = "Detomo/cl-nagoya-sup-simcse-ja-for-standard-name-v1_0"
-
-
-
+SENTENCE_EMBEDDING_FILE = os.path.join(
+    DATA_DIR,
+    "sample_representative_name_sentence_embeddings(cl-nagoya-sup-simcse-ja-for-standard-name-v1_0).pkl",
+)
data/sample_name_sentence_similarities(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl DELETED

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d942620d2940849fdee0f6cec443a5dd1f7f608144d4f1cee5ff66dd39797035
-size 137593306
data/{sample_name_sentence_similarities(cl-nagoya-sup-simcse-ja-base).pkl → sample_representative_name_sentence_embeddings(cl-nagoya-sup-simcse-ja-for-standard-name-v1_0).pkl} RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:aa42e9df217a42147a4a1e2a584a35462756f9f34646a6db981941cf89dc2095
+size 18217123
data/users.json ADDED

@@ -0,0 +1,16 @@
+{
+    "chien_vm": {
+        "username": "chien_vm",
+        "full_name": "Chien VM",
+        "email": "[email protected]",
+        "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
+        "disabled": false
+    },
+    "meiseidev": {
+        "username": "meiseidev",
+        "full_name": "meiseidev",
+        "email": "",
+        "hashed_password": "$2b$12$LXIbdQ388dMiN/ej76zuHeHuuz.VHz9rJfH4FpwdVbqfwCbSI55Va",
+        "disabled": false
+    }
+}
database.py CHANGED

@@ -1,16 +1,73 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
+import json
+import os
+from passlib.context import CryptContext
+
+# Create the data directory if it does not already exist
+os.makedirs("data", exist_ok=True)
+USERS_FILE = "data/users.json"
+
+# Context for password hashing
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+
+
+# Read users from the JSON file
+def get_users():
+    if not os.path.exists(USERS_FILE):
+        # Create the file with default data if it does not exist yet
+        default_users = {
+            "chien_vm": {
+                "username": "chien_vm",
+                "full_name": "Chien VM",
+                "email": "[email protected]",
+                "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
+                "disabled": False,
+            },
+            "hoi_nv": {
+                "username": "hoi_nv",
+                "full_name": "Hoi NV",
+                "email": "[email protected]",
+                "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
+                "disabled": False,
+            },
+        }
+        save_users(default_users)
+        return default_users
+
+    with open(USERS_FILE, "r", encoding="utf-8") as f:
+        return json.load(f)
+
+
+# Write users to the JSON file
+def save_users(users_data):
+    with open(USERS_FILE, "w", encoding="utf-8") as f:
+        json.dump(users_data, f, indent=4, ensure_ascii=False)
+
+
+# Create a new account
+def create_account(username, password):
+    # Check whether the username already exists
+    users = get_users()
+    if username in users:
+        return False, "Username already exists"
+
+    # Hash the password
+    hashed_password = pwd_context.hash(password)
+
+    # Build the new user record
+    new_user = {
+        "username": username,
+        "full_name": username,  # full_name defaults to the username
+        "email": "",  # email is not required
+        "hashed_password": hashed_password,
         "disabled": False,
     }
-
+
+    # Add the new user to the store
+    users[username] = new_user
+    save_users(users)
+
+    return True, "Account created successfully"
+
+
+# For backward compatibility with old code
+users_db = get_users()
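A quick local illustration of the new JSON-backed store (hypothetical usage, not part of the commit):

from database import create_account, get_users

ok, message = create_account("alice", "s3cret")  # bcrypt-hashes and writes data/users.json
print(ok, message)                               # True Account created successfully
print("alice" in get_users())                    # True, and it persists across restarts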
main.py CHANGED

@@ -36,7 +36,7 @@ app = FastAPI(
     openapi_tags=[
         {"name": "Health", "description": "Health check endpoints"},
         {"name": "Authentication", "description": "User authentication and token management"},
-        {"name": "Prediction", "description": "
+        {"name": "Prediction", "description": "Predict and process CSV files"},
     ]
 )
CHANGED
@@ -16,3 +16,8 @@ class User(BaseModel):
|
|
16 |
|
17 |
class UserInDB(User):
|
18 |
hashed_password: str
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
class UserInDB(User):
|
18 |
hashed_password: str
|
19 |
+
|
20 |
+
|
21 |
+
class UserCreate(BaseModel):
|
22 |
+
username: str
|
23 |
+
password: str
|
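UserCreate is the request body for the /register route added below; a quick illustrative instantiation (values are made up, and .dict() is the pydantic v1 accessor, spelled model_dump() in v2):

from models import UserCreate

payload = UserCreate(username="alice", password="s3cret")
print(payload.dict())  # {'username': 'alice', 'password': 's3cret'}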
requirements.txt CHANGED

Binary files a/requirements.txt and b/requirements.txt differ
routes/auth.py CHANGED

@@ -2,9 +2,9 @@ from fastapi import APIRouter, Depends, HTTPException, status
 from fastapi.security import OAuth2PasswordRequestForm
 from datetime import timedelta
 from auth import authenticate_user, create_access_token
-from models import Token
+from models import Token, UserCreate
 from config import ACCESS_TOKEN_EXPIRE_HOURS
-from database import
+from database import get_users, create_account

 router = APIRouter()

@@ -13,7 +13,7 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(
     """
     Endpoint to obtain an access token with a username and password
     """
-    user = authenticate_user(
+    user = authenticate_user(get_users(), form_data.username, form_data.password)
     if not user:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,

@@ -26,3 +26,14 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(
         data={"sub": user.username}, expires_delta=access_token_expires
     )
     return Token(access_token=access_token, token_type="bearer")
+
+
+@router.post("/register")
+async def register_user(user_data: UserCreate):
+    """
+    Endpoint to register a new account
+    """
+    success, message = create_account(user_data.username, user_data.password)
+    if not success:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
+    return {"message": message}
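A hypothetical client call against the new endpoint (the base URL and any router prefix are assumptions; only the /register route itself appears above):

import httpx

resp = httpx.post(
    "http://localhost:8000/register",  # assumed base URL
    json={"username": "alice", "password": "s3cret"},
)
print(resp.status_code, resp.json())  # 200 {'message': 'Account created successfully'}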
routes/predict.py CHANGED

@@ -7,6 +7,7 @@ from fastapi.responses import FileResponse
 from auth import get_current_user
 from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
 from data_lib.input_name_data import InputNameData
+from data_lib.base_data import COL_NAME_SENTENCE
 from mapping_lib.name_mapping_helper import NameMappingHelper
 from config import UPLOAD_DIR, OUTPUT_DIR

@@ -40,14 +41,23 @@ async def predict(
     inputData = InputNameData(sentence_service.dic_standard_subject)
     inputData.load_data_from_csv(input_file_path)
     inputData.process_data()
+    input_name_sentences = inputData.dataframe[COL_NAME_SENTENCE]
+    input_name_sentence_embeddings = sentence_service.sentenceTransformerHelper.create_embeddings(input_name_sentences)
+
+    # Create similarity matrix
+    similarity_matrix = sentence_service.sentenceTransformerHelper.create_similarity_matrix_from_embeddings(
+        sentence_service.sample_name_sentence_embeddings,
+        input_name_sentence_embeddings
+    )

     # Map standard names
     nameMappingHelper = NameMappingHelper(
         sentence_service.sentenceTransformerHelper,
         inputData,
         sentence_service.sampleData,
+        input_name_sentence_embeddings,
         sentence_service.sample_name_sentence_embeddings,
-
+        similarity_matrix,
     )
     df_predicted = nameMappingHelper.map_standard_names()
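create_similarity_matrix_from_embeddings is not shown in this diff; a minimal sketch of what a sample-by-input similarity matrix usually is (cosine similarity, which is an assumption about the helper's internals):

import numpy as np

def cosine_similarity_matrix(sample_emb: np.ndarray, input_emb: np.ndarray) -> np.ndarray:
    # Normalize rows, then a matrix product gives all pairwise cosine
    # similarities; entry [i, j] compares sample i with input j
    a = sample_emb / np.linalg.norm(sample_emb, axis=1, keepdims=True)
    b = input_emb / np.linalg.norm(input_emb, axis=1, keepdims=True)
    return a @ b.T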
services/sentence_transformer_service.py CHANGED

@@ -1,25 +1,20 @@
 import pickle
 from config import (
     MODEL_NAME,
-
-    SETENCE_SIMILARITY_FILE,
+    SENTENCE_EMBEDDING_FILE,
     SAMPLE_DATA_FILE, SUBJECT_DATA_FILE
 )
 from sentence_transformer_lib.sentence_transformer_helper import SentenceTransformerHelper
 from data_lib.subject_data import SubjectData
 from data_lib.sample_name_data import SampleNameData
-
-from data_lib.base_data import COL_STANDARD_NAME_KEY, COL_STANDARD_SUBJECT, COL_STANDARD_NAME
+

 class SentenceTransformerService:
     def __init__(self):
         self.sentenceTransformerHelper = None
         self.dic_standard_subject = None
         self.sample_name_sentence_embeddings = None
-        self.sample_name_sentence_similarities = None
         self.sampleData = None
-        self.sentence_clustering_lib = None
-        self.name_groups = None

     def load_model_data(self):
         """Load model and data only once at startup"""

@@ -38,26 +33,13 @@ class SentenceTransformerService:
         self.dic_standard_subject = SubjectData.create_standard_subject_dic_from_file(SUBJECT_DATA_FILE)

         # Load pre-computed embeddings and similarities
-        with open(
+        with open(SENTENCE_EMBEDDING_FILE, "rb") as f:
             self.sample_name_sentence_embeddings = pickle.load(f)

-        with open(SETENCE_SIMILARITY_FILE, "rb") as f:
-            self.sample_name_sentence_similarities = pickle.load(f)
-
         # Load and process sample data
         self.sampleData = SampleNameData()
         self.sampleData.load_data_from_csv(SAMPLE_DATA_FILE)
         self.sampleData.process_data()
-
-        # Create sentence clusters
-        self.sentence_clustering_lib = SentenceClusteringLib(self.sample_name_sentence_embeddings)
-        best_name_eps = 0.07
-        self.name_groups, _ = self.sentence_clustering_lib.create_sentence_cluster(best_name_eps)
-
-        self.sampleData._create_key_column(
-            COL_STANDARD_NAME_KEY, COL_STANDARD_SUBJECT, COL_STANDARD_NAME
-        )
-        self.sampleData.set_name_sentence_labels(self.name_groups)
         self.sampleData.build_search_tree()

         print("Models and data loaded successfully")
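The renamed .pkl above is just a pickled embedding array; a minimal sketch of the save/load round trip it implies (function names are illustrative, not from the repo):

import pickle

def save_embeddings(embeddings, path):
    # Offline precomputation step, so startup only needs a single load
    with open(path, "wb") as f:
        pickle.dump(embeddings, f)

def load_embeddings(path):
    with open(path, "rb") as f:
        return pickle.load(f)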