Spaces:
Running
Running
优化应用初始化逻辑,使用异步上下文管理器处理生命周期;改进模型加载机制,添加线程锁以确保线程安全;更新 Gunicorn 配置以提高性能和稳定性
Browse files
- app.py +27 -19
- blkeras.py +38 -23
- gunicorn.conf.py +17 -1
app.py
CHANGED
|
@@ -1,16 +1,37 @@
|
|
| 1 |
import os
|
| 2 |
from fastapi import FastAPI
|
| 3 |
from pydantic import BaseModel
|
| 4 |
-
from fastapi.middleware.wsgi import WSGIMiddleware
|
| 5 |
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
| 7 |
-
|
| 8 |
-
from
|
| 9 |
|
| 10 |
from RequestModel import PredictRequest
|
| 11 |
-
from us_stock import fetch_symbols
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# 添加 CORS 中间件和限流配置
|
| 16 |
app.add_middleware(
|
|
@@ -55,25 +76,12 @@ async def api_bbb(request: TextRequest):
|
|
| 55 |
result = request.text + 'bbb'
|
| 56 |
return {"result": result}
|
| 57 |
|
| 58 |
-
|
| 59 |
-
@app.on_event("startup")
|
| 60 |
-
async def initialize_symbols():
|
| 61 |
-
# 在 FastAPI 启动时初始化变量
|
| 62 |
-
await fetch_symbols()
|
| 63 |
-
|
| 64 |
# 优化预测路由
|
| 65 |
@app.post("/api/predict")
|
| 66 |
async def predict(request: PredictRequest):
|
| 67 |
from blkeras import predict
|
| 68 |
-
|
| 69 |
try:
|
| 70 |
-
|
| 71 |
-
import asyncio
|
| 72 |
-
result = await asyncio.to_thread(
|
| 73 |
-
predict,
|
| 74 |
-
request.text,
|
| 75 |
-
request.stock_codes
|
| 76 |
-
)
|
| 77 |
return result
|
| 78 |
except Exception as e:
|
| 79 |
return []
|
|
|
|
| 1 |
import os
|
| 2 |
from fastapi import FastAPI
|
| 3 |
from pydantic import BaseModel
|
|
|
|
| 4 |
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
| 6 |
+
import asyncio
|
| 7 |
+
from contextlib import asynccontextmanager
|
| 8 |
|
| 9 |
from RequestModel import PredictRequest
|
|
|
|
| 10 |
|
| 11 |
+
# 全局变量,用于跟踪初始化状态
|
| 12 |
+
is_initialized = False
|
| 13 |
+
initialization_lock = asyncio.Lock()
|
| 14 |
+
|
| 15 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown for FastAPI.

    On startup, ``initialize_application()`` is awaited at most once per
    process: the module-level ``is_initialized`` flag is checked and set while
    holding ``initialization_lock``. Control is then handed to the running
    application until shutdown. There is currently no shutdown cleanup.

    Args:
        app: The FastAPI instance this lifespan is attached to (not used).
    """
    global is_initialized

    # --- startup ---
    async with initialization_lock:
        first_start = not is_initialized
        if first_start:
            await initialize_application()
            is_initialized = True

    # Hand control to the running application.
    yield

    # --- shutdown ---
    # Nothing to release yet; add cleanup calls here when they are needed.
|
| 26 |
+
|
| 27 |
+
async def initialize_application():
    """Run the one-time startup work the service needs.

    At present this only awaits ``us_stock.fetch_symbols()``; further startup
    steps can be added after it.
    """
    # Imported inside the function, as before, so ``us_stock`` is not loaded
    # when this module itself is imported.
    import us_stock

    await us_stock.fetch_symbols()
|
| 33 |
+
|
| 34 |
+
app = FastAPI(lifespan=lifespan)
|
| 35 |
|
| 36 |
# 添加 CORS 中间件和限流配置
|
| 37 |
app.add_middleware(
|
|
|
|
| 76 |
result = request.text + 'bbb'
|
| 77 |
return {"result": result}
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
# 优化预测路由
|
| 80 |
@app.post("/api/predict")
async def predict(request: PredictRequest):
    """Run the stock prediction model for the posted text and stock codes.

    The synchronous ``blkeras.predict`` function is executed in a worker
    thread through ``asyncio.to_thread`` so the event loop is not blocked
    while the model runs.

    Args:
        request: Request body providing ``text`` and ``stock_codes``.

    Returns:
        Whatever ``blkeras.predict`` returns, or an empty list if it raised.
    """
    import logging

    # Aliased so the imported function does not shadow this route handler's
    # own name inside its body.
    from blkeras import predict as run_prediction

    try:
        return await asyncio.to_thread(
            run_prediction, request.text, request.stock_codes
        )
    except Exception:
        # Best-effort endpoint: clients still receive an empty list on
        # failure, but the traceback is logged instead of being discarded.
        logging.getLogger(__name__).exception("Prediction failed")
        return []
|
blkeras.py
CHANGED
|
@@ -27,35 +27,48 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
| 27 |
# 设置环境变量,指定 Hugging Face 缓存路径
|
| 28 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# 加载模型
|
| 31 |
model = None
|
| 32 |
-
if model is None:
|
| 33 |
-
# 从环境变量中获取 Hugging Face token
|
| 34 |
-
hf_token = os.environ.get("HF_Token")
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
if
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
|
| 60 |
|
| 61 |
|
|
@@ -106,6 +119,7 @@ def predict(text: str, stock_codes: list):
|
|
| 106 |
|
| 107 |
print(f"Input Text Length: {len(text)}, Start with: {text[:200] if len(text) > 200 else text}")
|
| 108 |
print("Input stock codes:", stock_codes)
|
|
|
|
| 109 |
|
| 110 |
start_time = datetime.now()
|
| 111 |
input_text = text
|
|
@@ -230,6 +244,7 @@ def predict(text: str, stock_codes: list):
|
|
| 230 |
# print(f"模型所需的输入层 {layer.name}, 形状: {layer.shape}")
|
| 231 |
|
| 232 |
# 使用模型进行预测
|
|
|
|
| 233 |
predictions = model.predict(features)
|
| 234 |
|
| 235 |
# 生成伪精准度值
|
|
|
|
| 27 |
# 设置环境变量,指定 Hugging Face 缓存路径
|
| 28 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
| 29 |
|
| 30 |
+
import threading
|
| 31 |
+
|
| 32 |
+
# 添加线程锁
|
| 33 |
+
model_lock = threading.Lock()
|
| 34 |
+
model_initialized = False
|
| 35 |
+
|
| 36 |
# 加载模型
|
| 37 |
model = None
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
def get_model():
    """Return the shared Keras model, loading it on first use.

    The first caller logs in to Hugging Face with the token from the
    ``HF_Token`` environment variable, downloads the model file and loads it
    with Keras. Later callers receive the cached module-level ``model``. The
    ``model_initialized`` flag is checked both before and after acquiring
    ``model_lock`` so that only one thread performs the load.

    Returns:
        The loaded Keras model.

    Raises:
        ValueError: If ``HF_Token`` is missing or empty in the environment.
    """
    global model, model_initialized

    # Fast path: the model is already loaded, no locking required.
    if model_initialized:
        return model

    with model_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if model_initialized:
            return model

        # Read the Hugging Face token from the environment.
        hf_token = os.environ.get("HF_Token")
        if not hf_token:
            raise ValueError(
                "Hugging Face token not found in environment variables."
            )

        # Log in with the Hugging Face API token (make sure it is read-only).
        login(token=hf_token)

        # Download the model file locally.
        model_path = hf_hub_download(
            repo_id="parkerjj/BuckLake-Stock-Model",
            filename="stock_prediction_model_1118_final.keras",
            use_auth_token=hf_token,
        )

        # NOTE(review): Keras reads KERAS_BACKEND when `keras` is first
        # imported - confirm this assignment runs before that import,
        # otherwise it has no effect on the backend in use.
        os.environ["KERAS_BACKEND"] = "jax"
        print(f"Loading saved model from {model_path}...")
        from model_build import TransformerEncoder, ExpandDimension, ConcatenateTimesteps

        # Custom layer classes Keras needs to deserialize the saved model.
        custom_layers = {
            "TransformerEncoder": TransformerEncoder,
            "ExpandDimension": ExpandDimension,
            "ConcatenateTimesteps": ConcatenateTimesteps,
        }
        model = keras.saving.load_model(model_path, custom_objects=custom_layers)

        model.summary()
        # Set the flag last so the fast path never sees a half-loaded model.
        model_initialized = True

    return model
|
| 72 |
|
| 73 |
|
| 74 |
|
|
|
|
| 119 |
|
| 120 |
print(f"Input Text Length: {len(text)}, Start with: {text[:200] if len(text) > 200 else text}")
|
| 121 |
print("Input stock codes:", stock_codes)
|
| 122 |
+
print("Current Time:", datetime.now())
|
| 123 |
|
| 124 |
start_time = datetime.now()
|
| 125 |
input_text = text
|
|
|
|
| 244 |
# print(f"模型所需的输入层 {layer.name}, 形状: {layer.shape}")
|
| 245 |
|
| 246 |
# 使用模型进行预测
|
| 247 |
+
model = get_model()
|
| 248 |
predictions = model.predict(features)
|
| 249 |
|
| 250 |
# 生成伪精准度值
|
gunicorn.conf.py
CHANGED
|
@@ -11,6 +11,9 @@ workers = multiprocessing.cpu_count() + 1
|
|
| 11 |
# 设置为2,增加并发处理能力
|
| 12 |
threads = 2
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
# 工作方式
|
| 15 |
worker_class = "uvicorn.workers.UvicornWorker"
|
| 16 |
|
|
@@ -27,7 +30,20 @@ worker_connections = 2000
|
|
| 27 |
|
| 28 |
# 工作模式
|
| 29 |
worker_tmp_dir = "/dev/shm" # 使用内存文件系统提高性能
|
| 30 |
-
preload_app = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
# 进程名称前缀
|
| 33 |
proc_name = 'gunicorn_fastapi'
|
|
|
|
| 11 |
# 设置为2,增加并发处理能力
|
| 12 |
threads = 2
|
| 13 |
|
| 14 |
+
# 请求超时时间
|
| 15 |
+
timeout = 600
|
| 16 |
+
|
| 17 |
# 工作方式
|
| 18 |
worker_class = "uvicorn.workers.UvicornWorker"
|
| 19 |
|
|
|
|
| 30 |
|
| 31 |
# 工作模式
|
| 32 |
worker_tmp_dir = "/dev/shm" # 使用内存文件系统提高性能
|
| 33 |
+
preload_app = False # 修改为 False,避免重复加载
|
| 34 |
+
|
| 35 |
+
# 添加新的配置
|
| 36 |
+
reload = False # 禁用自动重载
|
| 37 |
+
daemon = False # 非守护进程模式运行
|
| 38 |
+
|
| 39 |
+
# 添加应用初始化钩子
|
| 40 |
+
def when_ready(server):
    """Gunicorn ``when_ready`` hook: log that the server is ready.

    Args:
        server: Object passed in by Gunicorn; only its ``log`` is used.
    """
    ready_message = "Server is ready. Doing nothing."
    server.log.info(ready_message)
|
| 43 |
+
|
| 44 |
+
def post_fork(server, worker):
    """Gunicorn ``post_fork`` hook: log the pid of the worker just forked.

    Args:
        server: Object passed in by Gunicorn; only its ``log`` is used.
        worker: The newly forked worker; only its ``pid`` is read.
    """
    spawn_message = f"Worker spawned (pid: {worker.pid})"
    server.log.info(spawn_message)
|
| 47 |
|
| 48 |
# 进程名称前缀
|
| 49 |
proc_name = 'gunicorn_fastapi'
|