# utils.py — podcast-generation utilities: LLM dialogue streaming, TTS
# (fishaudio / Azure / OpenAI-compatible), audio merging, and WebDAV upload.
import asyncio
import glob
import io
import os
import re
import time
import hashlib
from typing import Any, Dict, Generator
import uuid
from openai import OpenAI
import requests
from fishaudio import fishaudio_tts
from prompts import LANGUAGE_MODIFIER, LENGTH_MODIFIERS, PODCAST_INFO_PROMPT, QUESTION_MODIFIER, SUMMARY_INFO_PROMPT, SYSTEM_PROMPT, TONE_MODIFIER
import json
from pydub import AudioSegment
from fastapi import UploadFile
from PyPDF2 import PdfReader
from schema import PodcastInfo, ShortDialogue, Summary
from constants import (
AUDIO_CACHE_DIR,
FIREWORKS_API_KEY,
FIREWORKS_BASE_URL,
#FIREWORKS_MODEL_ID,
FIREWORKS_MAX_TOKENS,
FIREWORKS_TEMPERATURE,
GRADIO_CLEAR_CACHE_OLDER_THAN,
JINA_KEY,
SPEECH_KEY,
SPEECH_REGION,
WEBDAV_HOSTNAME,
WEBDAV_USERNAME,
WEBDAV_PASSWORD,
WEBDAV_AUDIO_PATH,
)
import azure.cognitiveservices.speech as speechsdk
# 上传相关
from webdav3.client import Client as WebDAVClient
import os.path
import aiohttp
# WebDAV connection settings, sourced from the constants module.
WEBDAV_OPTIONS = {
    'webdav_hostname': WEBDAV_HOSTNAME,
    'webdav_login': WEBDAV_USERNAME,
    'webdav_password': WEBDAV_PASSWORD,
    'verbose': False
}
# Module-level WebDAV client shared by all upload calls.
webdav_client = WebDAVClient(WEBDAV_OPTIONS)
async def upload_to_webdav(file_path: str, filename: str, podcast_title: str = None):
    """Upload an audio file to the WebDAV server.

    When a podcast title is given, the remote file is named
    "<sanitized title>_<YYYYMMDD>.mp3"; otherwise the original filename
    is kept. Any error is logged and swallowed so a failed upload never
    breaks the calling flow.
    """
    try:
        remote_name = filename
        if podcast_title:
            # Strip characters that are illegal in file names.
            sanitized = re.sub(r'[\\/:*?"<>|]', '', podcast_title)
            remote_name = f"{sanitized}_{time.strftime('%Y%m%d')}.mp3"
        target = f"{WEBDAV_AUDIO_PATH}/{remote_name}"

        # Ensure the remote directory exists, then push the file. The
        # webdav3 client is blocking, so each call runs in a worker thread.
        exists = await asyncio.to_thread(webdav_client.check, WEBDAV_AUDIO_PATH)
        if not exists:
            await asyncio.to_thread(webdav_client.mkdir, WEBDAV_AUDIO_PATH)
        await asyncio.to_thread(
            webdav_client.upload_sync,
            remote_path=target,
            local_path=file_path,
        )
        print(f"File uploaded successfully to WebDAV as {remote_name}")
    except Exception as e:
        print(f"WebDAV upload error: {e}")
# OpenAI-compatible client pointed at the Fireworks endpoint; used for
# chat completions and (in generate_podcast_audio_by_openai) TTS.
fw_client = OpenAI(base_url=FIREWORKS_BASE_URL, api_key=FIREWORKS_API_KEY)
def generate_dialogue(pdfFile, textInput, tone, duration, language, model_id, providerllm) -> Generator[str, None, None]:
    """Stream a podcast dialogue generated by the LLM.

    Yields newline-delimited JSON messages:
      {"type": "error", ...}  when the prompt is too long,
      {"type": "chunk", ...}  for each raw streamed LLM chunk,
      {"type": "final", ...}  once, with the cleaned full dialogue.
    """
    modified_system_prompt = get_prompt(pdfFile, textInput, tone, duration, language)
    if (modified_system_prompt == False):
        yield json.dumps({
            "type": "error",
            "content": "Prompt is too long"
        }) + "\n"
        return

    def clean_dialogue_format(text: str) -> str:
        """Normalize LLM output into strict '**Speaker**: line' format."""
        # Strip a possible Markdown code fence around the dialogue.
        if '```' in text:
            matches = re.findall(r'```(?:.*?)\n([\s\S]*?)```', text, re.MULTILINE)
            if matches:
                text = matches[0]
        lines = text.strip().split('\n')
        cleaned_lines = []
        current_speaker = None
        for line in lines:
            line = line.strip()
            if not line:
                continue
            speaker_match = re.match(r'\*\*(Host|[^:]+)\*\*:\s*(.*)', line)
            if speaker_match:
                speaker, content = speaker_match.groups()
                if len(content) > 100:
                    # Split overly long utterances (>100 chars) into
                    # several shorter lines at sentence boundaries.
                    sentences = re.split(r'([.!?。!?]+)', content)
                    current_content = ""
                    for i in range(0, len(sentences) - 1, 2):
                        sentence = sentences[i] + (sentences[i + 1] if i + 1 < len(sentences) else "")
                        if current_content and len(current_content + sentence) > 100:
                            cleaned_lines.append(f"**{speaker}**: {current_content.strip()}")
                            current_content = sentence
                        else:
                            current_content += sentence
                    if current_content:
                        cleaned_lines.append(f"**{speaker}**: {current_content.strip()}")
                else:
                    cleaned_lines.append(f"**{speaker}**: {content}")
                current_speaker = speaker
            elif current_speaker and line:
                # Continuation line without a speaker tag: attribute it
                # to the previous speaker.
                cleaned_lines.append(f"**{current_speaker}**: {line}")
        return '\n'.join(cleaned_lines)

    full_response = ""
    llm_stream = call_llm_stream(SYSTEM_PROMPT, modified_system_prompt, ShortDialogue, model_id, providerllm, isJSON=False)
    for chunk in llm_stream:
        # Chunks are forwarded raw; cleaning a partial chunk is unreliable
        # (the previous per-chunk cleaning was computed and then discarded).
        yield json.dumps({"type": "chunk", "content": chunk}) + "\n"
        full_response += chunk
    # Clean the complete response once and send it as the final message.
    # (Fixed: the original computed this but yielded the raw text instead.)
    cleaned_response = clean_dialogue_format(full_response)
    yield json.dumps({"type": "final", "content": cleaned_response})
async def process_line(line, voice, provider):
    """Synthesize one dialogue line with the requested TTS provider.

    Unknown providers fall back to the OpenAI-compatible TTS backend.
    """
    handlers = {
        'fishaudio': generate_podcast_audio,
        'azure': generate_podcast_audio_by_azure,
        'openai': generate_podcast_audio_by_openai,
    }
    handler = handlers.get(provider, generate_podcast_audio_by_openai)
    return await handler(line['content'], voice)
async def generate_podcast_audio_by_openai(text: str, voice: str) -> "AudioSegment":
    """Synthesize speech via the OpenAI-compatible TTS endpoint.

    Uses the same configured client as the LLM calls (fw_client).
    Returns a pydub AudioSegment (fixed: the annotation previously
    claimed str). Logs and re-raises any synthesis error.
    """
    try:
        # The client call is blocking, so run it in a worker thread.
        response = await asyncio.to_thread(
            fw_client.audio.speech.create,
            model="tts-1",  # OpenAI TTS model
            voice=voice,    # e.g. alloy, echo, fable, onyx, nova, shimmer
            input=text
        )
        # Collect the streamed MP3 bytes in one pass (avoids the
        # quadratic bytes += accumulation of the original).
        audio_data = b"".join(response.iter_bytes())
        return AudioSegment.from_mp3(io.BytesIO(audio_data))
    except Exception as e:
        print(f"Error in generate_podcast_audio_by_openai: {e}")
        raise
async def generate_podcast_audio_by_azure(text: str, voice: str) -> "AudioSegment | None":
    """Synthesize speech with Azure Cognitive Services TTS.

    Returns a pydub AudioSegment on success, or None when synthesis
    fails (failure details are printed). Re-raises unexpected errors.
    """
    try:
        speech_config = speechsdk.SpeechConfig(subscription=SPEECH_KEY, region=SPEECH_REGION)
        speech_config.speech_synthesis_voice_name = voice
        # audio_config=None keeps the synthesized audio in memory
        # instead of playing it on an output device.
        synthesizer = speechsdk.SpeechSynthesizer(speech_config=speech_config, audio_config=None)
        # Both the SDK call and the blocking future.get() run in worker
        # threads so the event loop is never blocked.
        future = await asyncio.to_thread(synthesizer.speak_text_async, text)
        result = await asyncio.to_thread(future.get)
        print("Speech synthesis completed")
        if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
            print("Audio synthesized successfully")
            audio_data = result.audio_data
            # result.audio_data is decoded as WAV bytes here.
            audio_segment = AudioSegment.from_wav(io.BytesIO(audio_data))
            return audio_segment
        else:
            print(f"Speech synthesis failed: {result.reason}")
            if hasattr(result, 'cancellation_details'):
                print(f"Cancellation details: {result.cancellation_details.reason}")
                print(f"Cancellation error details: {result.cancellation_details.error_details}")
            return None
    except Exception as e:
        print(f"Error in generate_podcast_audio: {e}")
        raise
async def generate_podcast_audio(text: str, voice: str) -> str:
    # Default TTS entry point — delegates to the fish-audio backend.
    # NOTE(review): the declared return type (str) likely does not match
    # what fishaudio_tts actually returns — confirm against the fishaudio
    # module.
    return await generate_podcast_audio_by_fish(text,voice)
async def generate_podcast_audio_by_fish(text: str, voice: str) -> str:
    """Synthesize speech with fish-audio TTS.

    `voice` is passed through as the fish-audio reference voice id.
    Logs and re-raises any synthesis error.
    """
    try:
        return fishaudio_tts(text=text, reference_id=voice)
    except Exception as e:
        # Fixed: the error message previously named the wrong function.
        print(f"Error in generate_podcast_audio_by_fish: {e}")
        raise
async def process_lines_with_limit(lines, provider, host_voice, guest_voice, max_concurrency):
    """Synthesize all dialogue lines concurrently, capped by a semaphore.

    Host lines (speaker '主持人' or 'Host') use host_voice; everything
    else uses guest_voice. Results are returned in input order.
    """
    gate = asyncio.Semaphore(max_concurrency)

    async def synth_one(line):
        async with gate:
            is_host = line['speaker'] in ('主持人', 'Host')
            chosen_voice = host_voice if is_host else guest_voice
            return await process_line(line, chosen_voice, provider)

    return await asyncio.gather(*(synth_one(line) for line in lines))
async def combine_audio(task_status: Dict[str, Dict], task_id: str, text: str, language: str, provider: str, host_voice: str, guest_voice: str, podcast_info: PodcastInfo) -> str:
    """Turn dialogue text into one MP3: parse, synthesize, merge, export.

    Parses '**Speaker**: content' turns out of `text`, synthesizes each
    line with the chosen TTS provider, concatenates the segments, writes
    the MP3 to the audio cache, starts a fire-and-forget WebDAV upload,
    updates `task_status[task_id]`, and prunes stale cached files.
    Returns the audio URL; on error sets the task status to "failed"
    and returns None implicitly (original contract kept).
    """
    try:
        # Capture speaker/content pairs; a turn ends at the next '**' or EOF.
        dialogue_regex = r'\*\*([\s\S]*?)\*\*[::]\s*([\s\S]*?)(?=\*\*|$)'
        matches = re.findall(dialogue_regex, text, re.DOTALL)
        lines = [
            {
                "speaker": match[0],
                "content": match[1].strip(),
            }
            for match in matches
        ]
        print("Starting audio generation")
        # Azure tolerates more parallel requests than the other providers.
        audio_segments = await process_lines_with_limit(
            lines, provider, host_voice, guest_voice, 10 if provider == 'azure' else 5
        )
        print("Audio generation completed")
        # Concatenating AudioSegments is CPU-bound; keep it off the loop.
        combined_audio = await asyncio.to_thread(sum, audio_segments)
        print("Audio combined")
        # Write the merged audio to the cache exactly once, at the end.
        unique_filename = f"{uuid.uuid4()}.mp3"
        os.makedirs(AUDIO_CACHE_DIR, exist_ok=True)
        file_path = os.path.join(AUDIO_CACHE_DIR, unique_filename)
        await asyncio.to_thread(combined_audio.export, file_path, format="mp3")
        audio_url = f"/audio/{unique_filename}"
        # Fire-and-forget upload named after the podcast title; the task
        # status below is updated regardless of the upload outcome.
        asyncio.create_task(upload_to_webdav(file_path, unique_filename, podcast_info.title))
        task_status[task_id] = {
            "status": "completed",
            "audio_url": audio_url,
        }
        # Prune cached MP3s older than the configured TTL.
        # (Fixed: the original wrote `os.remove, file` — a tuple
        # expression that never deleted anything. Also use os.path.join
        # so the pattern works whether or not AUDIO_CACHE_DIR ends in a
        # separator.)
        for file in glob.glob(os.path.join(AUDIO_CACHE_DIR, "*.mp3")):
            if (
                os.path.isfile(file)
                and time.time() - os.path.getmtime(file) > GRADIO_CLEAR_CACHE_OLDER_THAN
            ):
                os.remove(file)
        clear_pdf_cache()
        return audio_url
    except Exception as e:
        task_status[task_id] = {"status": "failed", "error": str(e)}
def generate_podcast_summary(pdf_content: str, text: str, tone: str, length: str, language: str, model_id: str, providerllm: str) -> Generator[str, None, None]:
    """Stream a podcast summary as newline-delimited JSON messages.

    Yields {"type": "chunk"} messages while streaming and one final
    {"type": "final"} message with the complete summary text.
    NOTE(review): tone/length/language are accepted but not forwarded to
    get_prompt (empty strings are passed) — confirm this is intended.
    """
    modified_system_prompt = get_prompt(pdf_content, text, '', '', '')
    if (modified_system_prompt == False):
        yield json.dumps({
            "type": "error",
            "content": "Prompt is too long"
        }) + "\n"
        return
    stream = call_llm_stream(SUMMARY_INFO_PROMPT, modified_system_prompt, Summary, model_id, providerllm, False)
    full_response = ""
    for chunk in stream:
        yield json.dumps({"type": "chunk", "content": chunk}) + "\n"
        # Fixed: chunks were never accumulated, so the final message
        # always carried an empty string.
        full_response += chunk
    yield json.dumps({"type": "final", "content": full_response})
def generate_podcast_info(pdfContent: str, text: str, tone: str, length: str, language: str, model_id: str, providerllm: str) -> Generator[str, None, None]:
    """Generate podcast metadata (title, host name) as a JSON message.

    Streams the LLM response to completion, then cleans and validates
    the JSON, yielding a single {"type": "podcast_info"} message — or a
    {"type": "error"} message on any failure.
    """
    modified_system_prompt = get_prompt(pdfContent, text, '', '', '')
    if (modified_system_prompt == False):
        yield json.dumps({
            "type": "error",
            "content": "Prompt is too long"
        }) + "\n"
        return
    full_response = ""
    for chunk in call_llm_stream(PODCAST_INFO_PROMPT, modified_system_prompt, PodcastInfo, model_id, providerllm):
        full_response += chunk
    print("播客信息完整响应:", full_response)
    try:
        def clean_podcast_json(raw_json_str: str) -> dict:
            """Extract, parse and validate the podcast-info JSON blob."""
            # Strip a possible Markdown code fence.
            if '```' in raw_json_str:
                matches = re.findall(r'```(?:json)?([\s\S]*?)```', raw_json_str)
                if matches:
                    raw_json_str = matches[0]
            raw_json_str = raw_json_str.strip()
            try:
                result = json.loads(raw_json_str)
            except json.JSONDecodeError:
                # Fall back to the outermost brace-delimited span.
                match = re.search(r'{[\s\S]*}', raw_json_str)
                if match:
                    try:
                        result = json.loads(match.group(0))
                    except json.JSONDecodeError:
                        # Fixed: was a bare `except:` that also swallowed
                        # SystemExit/KeyboardInterrupt.
                        raise ValueError("Invalid JSON format")
                else:
                    raise ValueError("No valid JSON object found")
            # Require both fields to be present and be strings.
            required_fields = {
                "title": str,
                "host_name": str
            }
            cleaned_result = {}
            for field, field_type in required_fields.items():
                if field not in result or not isinstance(result[field], field_type):
                    raise ValueError(f"Missing or invalid field: {field}")
                cleaned_result[field] = result[field].strip()
            return cleaned_result

        result = clean_podcast_json(full_response)
        print("应该是正确的json:", result)
        yield json.dumps({
            "type": "podcast_info",
            "content": result
        }) + "\n"
    except Exception as e:
        yield json.dumps({
            "type": "error",
            "content": f"An unexpected error occurred: {str(e)}"
        }) + "\n"
def call_llm_stream(system_prompt: str, text: str, dialogue_format: Any, model_id: str, providerllm: str, isJSON: bool = True) -> Generator[str, None, None]:
    """Stream the LLM's response text chunk by chunk.

    When isJSON is True the request constrains output to
    dialogue_format's JSON schema. Only non-empty content deltas are
    yielded.
    """
    request_params = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
        ],
        "model": model_id,
        "max_tokens": FIREWORKS_MAX_TOKENS,
        "temperature": FIREWORKS_TEMPERATURE,
        # Enable streamed output.
        "stream": True,
    }
    if isJSON:
        # Ask the backend for schema-constrained JSON output.
        request_params["response_format"] = {
            "type": "json_object",
            "schema": dialogue_format.model_json_schema(),
        }
    full_response = ""
    for chunk in fw_client.chat.completions.create(**request_params):
        delta = chunk.choices[0].delta.content
        if delta is not None:
            full_response += delta
            yield delta
def call_llm(system_prompt: str, text: str, dialogue_format: Any, model_id: str, providerllm: str) -> Any:
    """Make a non-streaming LLM call and return the raw completion.

    The response is constrained to dialogue_format's JSON schema.
    """
    schema_format = {
        "type": "json_object",
        "schema": dialogue_format.model_json_schema(),
    }
    return fw_client.chat.completions.create(
        model=model_id,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
        ],
        max_tokens=FIREWORKS_MAX_TOKENS,
        temperature=FIREWORKS_TEMPERATURE,
        response_format=schema_format,
    )
# Cache of extracted PDF text, keyed by md5 of the file contents.
pdf_cache = {}

def clear_pdf_cache():
    """Empty the module-level PDF text cache in place."""
    # dict.clear() mutates the existing object, so no `global` rebinding
    # is needed and other references to pdf_cache stay valid.
    pdf_cache.clear()
def get_link_text(url: str):
    """Fetch a URL's main content as markdown via the magic-html API.

    Raises requests.HTTPError for HTTP-level failures and a generic
    Exception when the extractor reports failure. (A commented-out
    jina.ai implementation previously lived here and was removed.)
    """
    api_url = "https://magic-html-api.vercel.app/api/extract"
    params = {
        "url": url,
        "output_format": "markdown"
    }
    # Fixed: a requests call without a timeout can hang forever; also
    # fail fast on HTTP errors instead of mis-parsing an error body.
    response = requests.get(api_url, params=params, timeout=30)
    response.raise_for_status()
    result = response.json()
    if result["success"]:
        return result["content"]
    else:
        raise Exception("Failed to extract content from URL")
async def get_pdf_text(pdf_file: UploadFile):
    """Extract the full text of an uploaded PDF, with an md5-keyed cache.

    Returns the extracted text, or {"error": ...} on failure (the
    original error-dict contract is preserved for existing callers).
    """
    try:
        contents = await pdf_file.read()
        # Cache on the content hash so re-uploads of the same PDF are free.
        file_hash = hashlib.md5(contents).hexdigest()
        if file_hash in pdf_cache:
            return pdf_cache[file_hash]
        pdf_reader = PdfReader(io.BytesIO(contents))
        # extract_text() may return None (e.g. image-only pages); treat
        # those as empty so the join cannot crash.
        text = "\n\n".join(page.extract_text() or "" for page in pdf_reader.pages)
        # Fixed: the cache was read above but never written, so it could
        # never produce a hit.
        pdf_cache[file_hash] = text
        # Rewind in case the caller reads the upload again.
        await pdf_file.seek(0)
        return text
    except Exception as e:
        return {"error": str(e)}
def get_prompt(pdfContent: str, text: str, tone: str, length: str, language: str):
    """Assemble the user prompt from content plus optional modifiers.

    Each non-empty argument contributes a "\\n\\n"-prefixed section;
    empty arguments are skipped. Note that `text` is only included when
    pdfContent is also non-empty (original behavior preserved).
    """
    combined_content = pdfContent + text
    sections = []
    if pdfContent:
        sections.append(f"{QUESTION_MODIFIER} {combined_content}")
    if tone:
        sections.append(f"{TONE_MODIFIER} {tone}.")
    if length:
        sections.append(f"{LENGTH_MODIFIERS[length]}")
    if language:
        sections.append(f"{LANGUAGE_MODIFIER} {language}.")
    return "".join(f"\n\n{section}" for section in sections)