Spaces:
Running
Running
优化预测函数的输入文本打印逻辑,增加文本长度信息;改进长文本处理函数,考虑特殊标记长度以保持句子完整性
Browse files- app.py +2 -2
- preprocess.py +9 -3
app.py
CHANGED
|
@@ -52,8 +52,8 @@ async def predict(request: PredictRequest):
|
|
| 52 |
try:
|
| 53 |
input_text = request.text # FastAPI 会自动解析为 PredictRequest 对象
|
| 54 |
affected_stock_codes = request.stock_codes
|
| 55 |
-
print("Input text:", input_text)
|
| 56 |
-
print("Input stock codes:", affected_stock_codes)
|
| 57 |
return predict(input_text, affected_stock_codes)
|
| 58 |
except Exception as e:
|
| 59 |
return {"error": str(e)}
|
|
|
|
| 52 |
try:
|
| 53 |
input_text = request.text # FastAPI 会自动解析为 PredictRequest 对象
|
| 54 |
affected_stock_codes = request.stock_codes
|
| 55 |
+
print(f"Input Text Length: {len(input_text)}, Start with: {input_text[:200] if len(input_text) > 200 else input_text}")
|
| 56 |
+
print("Input stock codes:", affected_stock_codes)
|
| 57 |
return predict(input_text, affected_stock_codes)
|
| 58 |
except Exception as e:
|
| 59 |
return {"error": str(e)}
|
preprocess.py
CHANGED
|
@@ -10,6 +10,7 @@ import pandas as pd
|
|
| 10 |
import time
|
| 11 |
|
| 12 |
# 如果使用 spaCy 进行 NLP 处理
|
|
|
|
| 13 |
import spacy
|
| 14 |
|
| 15 |
# 如果使用某种情感分析工具,比如 Hugging Face 的模型
|
|
@@ -225,7 +226,7 @@ def get_document_vector(words, model = word2vec_model):
|
|
| 225 |
# 函数:获取情感得分
|
| 226 |
def process_long_text(text, tokenizer, max_length=512):
|
| 227 |
"""
|
| 228 |
-
|
| 229 |
"""
|
| 230 |
import nltk
|
| 231 |
try:
|
|
@@ -239,15 +240,19 @@ def process_long_text(text, tokenizer, max_length=512):
|
|
| 239 |
nltk.download('punkt_tab')
|
| 240 |
|
| 241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
sentences = nltk.sent_tokenize(text)
|
| 243 |
segments = []
|
| 244 |
current_segment = ""
|
| 245 |
|
| 246 |
for sentence in sentences:
|
| 247 |
-
print(f"Processing sentence: {sentence}")
|
| 248 |
# 检查添加当前句子后是否会超过最大长度
|
| 249 |
test_segment = current_segment + " " + sentence if current_segment else sentence
|
| 250 |
-
if len(tokenizer.tokenize(test_segment)) > max_length:
|
| 251 |
if current_segment:
|
| 252 |
segments.append(current_segment.strip())
|
| 253 |
current_segment = sentence
|
|
@@ -340,6 +345,7 @@ def get_sentiment_score(text):
|
|
| 340 |
return 0.0
|
| 341 |
|
| 342 |
|
|
|
|
| 343 |
def get_stock_info(stock_code: str, history_days=30):
|
| 344 |
# 获取股票代码和新闻日期
|
| 345 |
|
|
|
|
| 10 |
import time
|
| 11 |
|
| 12 |
# 如果使用 spaCy 进行 NLP 处理
|
| 13 |
+
from regex import R
|
| 14 |
import spacy
|
| 15 |
|
| 16 |
# 如果使用某种情感分析工具,比如 Hugging Face 的模型
|
|
|
|
| 226 |
# 函数:获取情感得分
|
| 227 |
def process_long_text(text, tokenizer, max_length=512):
|
| 228 |
"""
|
| 229 |
+
将长文本分段并保持句子完整性,同时考虑特殊标记的长度
|
| 230 |
"""
|
| 231 |
import nltk
|
| 232 |
try:
|
|
|
|
| 240 |
nltk.download('punkt_tab')
|
| 241 |
|
| 242 |
|
| 243 |
+
# 计算特殊标记占用的长度(CLS, SEP等)
|
| 244 |
+
special_tokens_count = tokenizer.num_special_tokens_to_add()
|
| 245 |
+
# 实际可用于文本的最大长度
|
| 246 |
+
effective_max_length = max_length - special_tokens_count
|
| 247 |
+
|
| 248 |
sentences = nltk.sent_tokenize(text)
|
| 249 |
segments = []
|
| 250 |
current_segment = ""
|
| 251 |
|
| 252 |
for sentence in sentences:
|
|
|
|
| 253 |
# 检查添加当前句子后是否会超过最大长度
|
| 254 |
test_segment = current_segment + " " + sentence if current_segment else sentence
|
| 255 |
+
if len(tokenizer.tokenize(test_segment)) > effective_max_length:
|
| 256 |
if current_segment:
|
| 257 |
segments.append(current_segment.strip())
|
| 258 |
current_segment = sentence
|
|
|
|
| 345 |
return 0.0
|
| 346 |
|
| 347 |
|
| 348 |
+
|
| 349 |
def get_stock_info(stock_code: str, history_days=30):
|
| 350 |
# 获取股票代码和新闻日期
|
| 351 |
|