Spaces:
Running
Running
change model
Browse files
- config.py +2 -0
- data/sampleDataInput.csv +0 -0
- requirements.txt +0 -0
- routes/predict.py +14 -2
- services/sentence_transformer_service.py +22 -8
config.py
CHANGED
@@ -19,3 +19,5 @@ SENTENCE_EMBEDDING_FILE = os.path.join(
|
|
19 |
DATA_DIR,
|
20 |
"anchor_name_sentence_sentence_embeddings(cl-nagoya-sup-simcse-ja-for-standard-name-v0_9_10).pkl",
|
21 |
)
|
|
|
|
|
|
19 |
DATA_DIR,
|
20 |
"anchor_name_sentence_sentence_embeddings(cl-nagoya-sup-simcse-ja-for-standard-name-v0_9_10).pkl",
|
21 |
)
|
22 |
+
MODEL_TYPE = "openvino"
|
23 |
+
DEVICE_TYPE = "cpu"
|
data/sampleDataInput.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
|
|
routes/predict.py
CHANGED
@@ -8,6 +8,7 @@ from auth import get_current_user
|
|
8 |
from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
|
9 |
from data_lib.input_name_data import InputNameData
|
10 |
from data_lib.base_name_data import COL_NAME_SENTENCE
|
|
|
11 |
from mapping_lib.name_mapper import NameMapper
|
12 |
from config import UPLOAD_DIR, OUTPUT_DIR
|
13 |
from models import (
|
@@ -48,9 +49,16 @@ async def predict(
|
|
48 |
# Process input data
|
49 |
start_time = time.time()
|
50 |
try:
|
51 |
-
inputData = InputNameData(
|
52 |
inputData.load_data_from_csv(input_file_path)
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
except Exception as e:
|
55 |
print(f"Error processing input data: {e}")
|
56 |
raise HTTPException(status_code=500, detail=str(e))
|
@@ -107,9 +115,13 @@ async def create_embeddings(
|
|
107 |
Create embeddings for a list of input sentences (requires authentication)
|
108 |
"""
|
109 |
try:
|
|
|
110 |
embeddings = sentence_service.sentenceTransformerHelper.create_embeddings(
|
111 |
request.sentences
|
112 |
)
|
|
|
|
|
|
|
113 |
# Convert numpy array to list for JSON serialization
|
114 |
embeddings_list = embeddings.tolist()
|
115 |
return {"embeddings": embeddings_list}
|
|
|
8 |
from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
|
9 |
from data_lib.input_name_data import InputNameData
|
10 |
from data_lib.base_name_data import COL_NAME_SENTENCE
|
11 |
+
from mapping_lib.subject_mapper import SubjectMapper
|
12 |
from mapping_lib.name_mapper import NameMapper
|
13 |
from config import UPLOAD_DIR, OUTPUT_DIR
|
14 |
from models import (
|
|
|
49 |
# Process input data
|
50 |
start_time = time.time()
|
51 |
try:
|
52 |
+
inputData = InputNameData()
|
53 |
inputData.load_data_from_csv(input_file_path)
|
54 |
+
except Exception as e:
|
55 |
+
print(f"Error processing input data: {e}")
|
56 |
+
raise HTTPException(status_code=500, detail=str(e))
|
57 |
+
try:
|
58 |
+
subject_mapper = SubjectMapper(sentence_service.sentenceTransformerHelper, sentence_service.dic_standard_subject)
|
59 |
+
dic_subject_map = subject_mapper.map_standard_subjects(inputData.dataframe)
|
60 |
+
inputData.dic_standard_subject = dic_subject_map
|
61 |
+
inputData.process_data()
|
62 |
except Exception as e:
|
63 |
print(f"Error processing input data: {e}")
|
64 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
115 |
Create embeddings for a list of input sentences (requires authentication)
|
116 |
"""
|
117 |
try:
|
118 |
+
start_time = time.time()
|
119 |
embeddings = sentence_service.sentenceTransformerHelper.create_embeddings(
|
120 |
request.sentences
|
121 |
)
|
122 |
+
end_time = time.time()
|
123 |
+
execution_time = end_time - start_time
|
124 |
+
print(f"Execution time: {execution_time} seconds")
|
125 |
# Convert numpy array to list for JSON serialization
|
126 |
embeddings_list = embeddings.tolist()
|
127 |
return {"embeddings": embeddings_list}
|
services/sentence_transformer_service.py
CHANGED
@@ -1,6 +1,22 @@
|
|
1 |
import pickle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from config import (
|
3 |
-
MODEL_NAME,
|
4 |
SENTENCE_EMBEDDING_FILE,
|
5 |
STANDARD_NAME_MAP_DATA_FILE, SUBJECT_DATA_FILE
|
6 |
)
|
@@ -24,10 +40,8 @@ class SentenceTransformerService:
|
|
24 |
|
25 |
print("Loading models and data...")
|
26 |
# Load sentence transformer model
|
27 |
-
self.sentenceTransformerHelper = SentenceTransformerHelper(
|
28 |
-
|
29 |
-
)
|
30 |
-
self.sentenceTransformerHelper.load_model_by_name(MODEL_NAME)
|
31 |
|
32 |
# Load standard subject dictionary
|
33 |
self.dic_standard_subject = SubjectData.create_standard_subject_dic_from_file(SUBJECT_DATA_FILE)
|
@@ -37,10 +51,10 @@ class SentenceTransformerService:
|
|
37 |
self.anchor_name_sentence_embeddings = pickle.load(f)
|
38 |
|
39 |
# Load and process sample data
|
40 |
-
self.standardNameMapData = StandardNameMapData()
|
41 |
self.standardNameMapData.load_data_from_csv(STANDARD_NAME_MAP_DATA_FILE)
|
42 |
-
self.standardNameMapData.process_data(
|
43 |
-
|
44 |
print("Models and data loaded successfully")
|
45 |
|
46 |
# Global instance (singleton)
|
|
|
1 |
import pickle
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
# Filter NumPy array implementation warnings specifically
|
5 |
+
warnings.filterwarnings(
|
6 |
+
"ignore",
|
7 |
+
message=".*array.*implementation doesn't accept a copy keyword.*",
|
8 |
+
category=DeprecationWarning
|
9 |
+
)
|
10 |
+
|
11 |
+
# Also ignore the related DeprecationWarning about NumPy passing `copy` to `__array__`:
|
12 |
+
warnings.filterwarnings(
|
13 |
+
"ignore",
|
14 |
+
message=".*NumPy will pass `copy` to the `__array__` special method.*",
|
15 |
+
category=DeprecationWarning
|
16 |
+
)
|
17 |
+
|
18 |
from config import (
|
19 |
+
MODEL_NAME, MODEL_TYPE, DEVICE_TYPE,
|
20 |
SENTENCE_EMBEDDING_FILE,
|
21 |
STANDARD_NAME_MAP_DATA_FILE, SUBJECT_DATA_FILE
|
22 |
)
|
|
|
40 |
|
41 |
print("Loading models and data...")
|
42 |
# Load sentence transformer model
|
43 |
+
self.sentenceTransformerHelper = SentenceTransformerHelper(model_name=MODEL_NAME, model_type=MODEL_TYPE)
|
44 |
+
print(f"Loading model {MODEL_NAME} with type {MODEL_TYPE}")
|
|
|
|
|
45 |
|
46 |
# Load standard subject dictionary
|
47 |
self.dic_standard_subject = SubjectData.create_standard_subject_dic_from_file(SUBJECT_DATA_FILE)
|
|
|
51 |
self.anchor_name_sentence_embeddings = pickle.load(f)
|
52 |
|
53 |
# Load and process sample data
|
54 |
+
self.standardNameMapData = StandardNameMapData(self.anchor_name_sentence_embeddings)
|
55 |
self.standardNameMapData.load_data_from_csv(STANDARD_NAME_MAP_DATA_FILE)
|
56 |
+
self.standardNameMapData.process_data()
|
57 |
+
|
58 |
print("Models and data loaded successfully")
|
59 |
|
60 |
# Global instance (singleton)
|