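"""Core podcast-generation helpers: prompt assembly, streaming LLM calls through an
OpenAI-compatible client (Fireworks), text-to-speech via fishaudio / Azure / OpenAI-style TTS,
audio concatenation and MP3 export with pydub, WebDAV upload, and PDF/URL text extraction."""
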
import asyncio
import glob
import hashlib
import io
import json
import os
import re
import time
import uuid
from typing import Any, Dict, Generator, Optional

import aiohttp
import azure.cognitiveservices.speech as speechsdk
import requests
from fastapi import UploadFile
from openai import OpenAI
from pydub import AudioSegment
from PyPDF2 import PdfReader
from webdav3.client import Client as WebDAVClient

from constants import (
    AUDIO_CACHE_DIR,
    FIREWORKS_API_KEY,
    FIREWORKS_BASE_URL,
    FIREWORKS_MAX_TOKENS,
    FIREWORKS_TEMPERATURE,
    GRADIO_CLEAR_CACHE_OLDER_THAN,
    JINA_KEY,
    SPEECH_KEY,
    SPEECH_REGION,
    WEBDAV_HOSTNAME,
    WEBDAV_USERNAME,
    WEBDAV_PASSWORD,
    WEBDAV_AUDIO_PATH,
)
from fishaudio import fishaudio_tts
from prompts import (
    LANGUAGE_MODIFIER,
    LENGTH_MODIFIERS,
    PODCAST_INFO_PROMPT,
    QUESTION_MODIFIER,
    SUMMARY_INFO_PROMPT,
    SYSTEM_PROMPT,
    TONE_MODIFIER,
)
from schema import PodcastInfo, ShortDialogue, Summary

WEBDAV_OPTIONS = {
    'webdav_hostname': WEBDAV_HOSTNAME,
    'webdav_login': WEBDAV_USERNAME,
    'webdav_password': WEBDAV_PASSWORD,
    'verbose': False,
}

webdav_client = WebDAVClient(WEBDAV_OPTIONS)


async def upload_to_webdav(file_path: str, filename: str, podcast_title: Optional[str] = None):
    """Upload a file to the WebDAV server, using the podcast title as the filename when available."""
    try:
        upload_filename = filename
        if podcast_title:
            timestamp = time.strftime("%Y%m%d")
            # Strip characters that are not allowed in filenames.
            safe_title = re.sub(r'[\\/:*?"<>|]', '', podcast_title)
            upload_filename = f"{safe_title}_{timestamp}.mp3"

        remote_path = f"{WEBDAV_AUDIO_PATH}/{upload_filename}"

        # Make sure the remote directory exists before uploading.
        if not await asyncio.to_thread(webdav_client.check, WEBDAV_AUDIO_PATH):
            await asyncio.to_thread(webdav_client.mkdir, WEBDAV_AUDIO_PATH)

        await asyncio.to_thread(webdav_client.upload_sync, remote_path=remote_path, local_path=file_path)
        print(f"File uploaded successfully to WebDAV as {upload_filename}")

    except Exception as e:
        print(f"WebDAV upload error: {e}")


# OpenAI-compatible client pointed at the Fireworks endpoint; reused for chat completions and TTS calls.
fw_client = OpenAI(base_url=FIREWORKS_BASE_URL, api_key=FIREWORKS_API_KEY)


def generate_dialogue(pdfFile, textInput, tone, duration, language, model_id, providerllm) -> Generator[str, None, None]:
    """Stream a podcast dialogue as newline-delimited JSON events."""
    modified_system_prompt = get_prompt(pdfFile, textInput, tone, duration, language)
    if modified_system_prompt is False:
        yield json.dumps({
            "type": "error",
            "content": "Prompt is too long"
        }) + "\n"
        return

    def clean_dialogue_format(text: str) -> str:
        """Normalize the LLM output to one '**Speaker**: text' line per utterance."""
        # If the model wrapped its answer in a fenced code block, keep only the block contents.
        if '```' in text:
            matches = re.findall(r'```(?:.*?)\n([\s\S]*?)```', text, re.MULTILINE)
            if matches:
                text = matches[0]

        lines = text.strip().split('\n')
        cleaned_lines = []
        current_speaker = None

        for line in lines:
            line = line.strip()
            if not line:
                continue

            speaker_match = re.match(r'\*\*(Host|[^:]+)\*\*:\s*(.*)', line)
            if speaker_match:
                speaker, content = speaker_match.groups()

                if len(content) > 100:
                    # Split long utterances on sentence boundaries (ASCII and CJK punctuation).
                    sentences = re.split(r'([.!?。!?]+)', content)
                    current_content = ""
                    for i in range(0, len(sentences) - 1, 2):
                        sentence = sentences[i] + (sentences[i + 1] if i + 1 < len(sentences) else "")
                        if current_content and len(current_content + sentence) > 100:
                            cleaned_lines.append(f"**{speaker}**: {current_content.strip()}")
                            current_content = sentence
                        else:
                            current_content += sentence
                    if current_content:
                        cleaned_lines.append(f"**{speaker}**: {current_content.strip()}")
                else:
                    cleaned_lines.append(f"**{speaker}**: {content}")
                current_speaker = speaker
            elif current_speaker and line:
                # Continuation lines are attributed to the most recent speaker.
                cleaned_lines.append(f"**{current_speaker}**: {line}")

        return '\n'.join(cleaned_lines)

    full_response = ""
    llm_stream = call_llm_stream(SYSTEM_PROMPT, modified_system_prompt, ShortDialogue, model_id, providerllm, isJSON=False)

    # Stream raw chunks as they arrive, then emit a cleaned-up transcript as the final event.
    for chunk in llm_stream:
        yield json.dumps({"type": "chunk", "content": chunk}) + "\n"
        full_response += chunk

    cleaned_response = clean_dialogue_format(full_response)
    yield json.dumps({"type": "final", "content": cleaned_response}) + "\n"
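
# The generators in this module (generate_dialogue, generate_podcast_summary, generate_podcast_info)
# emit newline-delimited JSON events of the form
# {"type": "chunk" | "final" | "error" | "podcast_info", "content": ...}.
# A minimal consumer sketch (the tone/duration/language values are illustrative assumptions):
#
#   for raw_event in generate_dialogue(pdf_text, "", "casual", "medium", "English", model_id, "fireworks"):
#       event = json.loads(raw_event)
#       if event["type"] == "chunk":
#           print(event["content"], end="")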


async def process_line(line, voice, provider):
    """Synthesize a single dialogue line with the selected TTS provider."""
    if provider == 'fishaudio':
        return await generate_podcast_audio(line['content'], voice)
    elif provider == 'azure':
        return await generate_podcast_audio_by_azure(line['content'], voice)
    elif provider == 'openai':
        return await generate_podcast_audio_by_openai(line['content'], voice)
    # Unknown providers fall back to the OpenAI-style backend.
    return await generate_podcast_audio_by_openai(line['content'], voice)
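
# Example (sketch): synthesizing one line with the Azure backend from inside an async context.
# The voice name "zh-CN-XiaoxiaoNeural" is an illustrative assumption, not a value defined in this project:
#
#   segment = await process_line({"speaker": "Host", "content": "Hello there"}, "zh-CN-XiaoxiaoNeural", "azure")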


async def generate_podcast_audio_by_openai(text: str, voice: str) -> AudioSegment:
    """Synthesize speech via the OpenAI-style TTS endpoint and return a pydub AudioSegment."""
    try:
        # Run the blocking SDK call in a worker thread.
        response = await asyncio.to_thread(
            fw_client.audio.speech.create,
            model="tts-1",
            voice=voice,
            input=text,
        )

        audio_data = b''
        for chunk in response.iter_bytes():
            audio_data += chunk

        audio_bytes = io.BytesIO(audio_data)
        audio_segment = AudioSegment.from_mp3(audio_bytes)
        return audio_segment

    except Exception as e:
        print(f"Error in generate_podcast_audio_by_openai: {e}")
        raise


async def generate_podcast_audio_by_azure(text: str, voice: str) -> Optional[AudioSegment]:
    """Synthesize speech with Azure Cognitive Services; returns a pydub AudioSegment, or None on failure."""
    try:
        speech_config = speechsdk.SpeechConfig(subscription=SPEECH_KEY, region=SPEECH_REGION)
        speech_config.speech_synthesis_voice_name = voice

        synthesizer = speechsdk.SpeechSynthesizer(speech_config=speech_config, audio_config=None)
        future = await asyncio.to_thread(synthesizer.speak_text_async, text)
        result = await asyncio.to_thread(future.get)

        print("Speech synthesis completed")

        if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
            print("Audio synthesized successfully")
            audio_data = result.audio_data
            audio_segment = AudioSegment.from_wav(io.BytesIO(audio_data))
            return audio_segment
        else:
            print(f"Speech synthesis failed: {result.reason}")
            if hasattr(result, 'cancellation_details'):
                print(f"Cancellation details: {result.cancellation_details.reason}")
                print(f"Cancellation error details: {result.cancellation_details.error_details}")
            return None

    except Exception as e:
        print(f"Error in generate_podcast_audio_by_azure: {e}")
        raise


async def generate_podcast_audio(text: str, voice: str) -> AudioSegment:
    """Default TTS entry point; currently delegates to the fishaudio backend."""
    return await generate_podcast_audio_by_fish(text, voice)


async def generate_podcast_audio_by_fish(text: str, voice: str) -> AudioSegment:
    try:
        return fishaudio_tts(text=text, reference_id=voice)
    except Exception as e:
        print(f"Error in generate_podcast_audio_by_fish: {e}")
        raise


async def process_lines_with_limit(lines, provider, host_voice, guest_voice, max_concurrency):
    """Synthesize all dialogue lines concurrently, capped at max_concurrency simultaneous TTS requests."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def limited_process_line(line):
        async with semaphore:
            # Lines attributed to '主持人' or 'Host' use the host voice; everything else uses the guest voice.
            voice = host_voice if (line['speaker'] == '主持人' or line['speaker'] == 'Host') else guest_voice
            return await process_line(line, voice, provider)

    tasks = [limited_process_line(line) for line in lines]
    results = await asyncio.gather(*tasks)
    return results


async def combine_audio(task_status: Dict[str, Dict], task_id: str, text: str, language: str, provider: str, host_voice: str, guest_voice: str, podcast_info: PodcastInfo) -> Optional[str]:
    """Parse the dialogue text, synthesize and concatenate the audio, export an MP3, and record the result in task_status."""
    try:
        # Each utterance has the form '**Speaker**: content' (ASCII or fullwidth colon).
        dialogue_regex = r'\*\*([\s\S]*?)\*\*[::]\s*([\s\S]*?)(?=\*\*|$)'
        matches = re.findall(dialogue_regex, text, re.DOTALL)

        lines = [
            {
                "speaker": match[0],
                "content": match[1].strip(),
            }
            for match in matches
        ]

        print("Starting audio generation")

        # Azure tolerates more concurrent requests than the other backends.
        audio_segments = await process_lines_with_limit(lines, provider, host_voice, guest_voice, 10 if provider == 'azure' else 5)
        print("Audio generation completed")

        # Concatenating the segments can be slow, so keep it off the event loop.
        combined_audio = await asyncio.to_thread(sum, audio_segments)
        print("Audio combined")

        unique_filename = f"{uuid.uuid4()}.mp3"
        os.makedirs(AUDIO_CACHE_DIR, exist_ok=True)
        file_path = os.path.join(AUDIO_CACHE_DIR, unique_filename)

        await asyncio.to_thread(combined_audio.export, file_path, format="mp3")

        audio_url = f"/audio/{unique_filename}"

        # Upload to WebDAV in the background; the caller does not wait for it.
        asyncio.create_task(upload_to_webdav(file_path, unique_filename, podcast_info.title))

        task_status[task_id] = {
            "status": "completed",
            "audio_url": audio_url,
        }

        # Evict cached audio files older than the configured age.
        for file in glob.glob(f"{AUDIO_CACHE_DIR}*.mp3"):
            if (
                os.path.isfile(file)
                and time.time() - os.path.getmtime(file) > GRADIO_CLEAR_CACHE_OLDER_THAN
            ):
                await asyncio.to_thread(os.remove, file)

        clear_pdf_cache()
        return audio_url

    except Exception as e:
        task_status[task_id] = {"status": "failed", "error": str(e)}


def generate_podcast_summary(pdf_content: str, text: str, tone: str, length: str, language: str, model_id: str, providerllm: str) -> Generator[str, None, None]:
    """Stream a podcast summary as newline-delimited JSON events."""
    modified_system_prompt = get_prompt(pdf_content, text, '', '', '')
    if modified_system_prompt is False:
        yield json.dumps({
            "type": "error",
            "content": "Prompt is too long"
        }) + "\n"
        return

    stream = call_llm_stream(SUMMARY_INFO_PROMPT, modified_system_prompt, Summary, model_id, providerllm, False)
    full_response = ""
    for chunk in stream:
        full_response += chunk  # accumulate so the final event carries the complete summary
        yield json.dumps({"type": "chunk", "content": chunk}) + "\n"

    yield json.dumps({"type": "final", "content": full_response}) + "\n"


def generate_podcast_info(pdfContent: str, text: str, tone: str, length: str, language: str, model_id: str, providerllm: str) -> Generator[str, None, None]:
    """Stream podcast metadata (title and host name) as newline-delimited JSON events."""
    modified_system_prompt = get_prompt(pdfContent, text, '', '', '')

    if modified_system_prompt is False:
        yield json.dumps({
            "type": "error",
            "content": "Prompt is too long"
        }) + "\n"
        return

    full_response = ""
    for chunk in call_llm_stream(PODCAST_INFO_PROMPT, modified_system_prompt, PodcastInfo, model_id, providerllm):
        full_response += chunk
    print("Podcast info, full response:", full_response)

    try:
        def clean_podcast_json(raw_json_str: str) -> dict:
            """Extract and validate the podcast-info JSON object from the raw LLM output."""
            # If the model wrapped its answer in a fenced code block, keep only the block contents.
            if '```' in raw_json_str:
                matches = re.findall(r'```(?:json)?([\s\S]*?)```', raw_json_str)
                if matches:
                    raw_json_str = matches[0]

            raw_json_str = raw_json_str.strip()

            try:
                result = json.loads(raw_json_str)
            except json.JSONDecodeError:
                # Fall back to the first brace-delimited block in the text.
                match = re.search(r'{[\s\S]*}', raw_json_str)
                if match:
                    try:
                        result = json.loads(match.group(0))
                    except json.JSONDecodeError:
                        raise ValueError("Invalid JSON format")
                else:
                    raise ValueError("No valid JSON object found")

            required_fields = {
                "title": str,
                "host_name": str,
            }

            cleaned_result = {}
            for field, field_type in required_fields.items():
                if field not in result or not isinstance(result[field], field_type):
                    raise ValueError(f"Missing or invalid field: {field}")
                cleaned_result[field] = result[field].strip()

            return cleaned_result

        result = clean_podcast_json(full_response)
        print("Should be valid JSON now:", result)

        yield json.dumps({
            "type": "podcast_info",
            "content": result
        }) + "\n"
    except Exception as e:
        yield json.dumps({
            "type": "error",
            "content": f"An unexpected error occurred: {str(e)}"
        }) + "\n"


def call_llm_stream(system_prompt: str, text: str, dialogue_format: Any, model_id: str, providerllm: str, isJSON: bool = True) -> Generator[str, None, None]:
    """Call the LLM with the given prompt and dialogue format, yielding the response as a stream of text chunks."""
    request_params = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
        ],
        "model": model_id,
        "max_tokens": FIREWORKS_MAX_TOKENS,
        "temperature": FIREWORKS_TEMPERATURE,
        "stream": True,
    }

    if isJSON:
        # Ask for structured output matching the pydantic schema.
        request_params["response_format"] = {
            "type": "json_object",
            "schema": dialogue_format.model_json_schema(),
        }

    stream = fw_client.chat.completions.create(**request_params)

    full_response = ""
    for chunk in stream:
        if chunk.choices[0].delta.content is not None:
            content = chunk.choices[0].delta.content
            full_response += content
            yield content
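
# Example (sketch): streaming a summary through call_llm_stream. The model id comes from the
# request in this project; "my-model-id" below is an illustrative placeholder, and providerllm
# is not used inside this function:
#
#   for piece in call_llm_stream(SUMMARY_INFO_PROMPT, "source text", Summary, "my-model-id", "fireworks", isJSON=False):
#       print(piece, end="")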


def call_llm(system_prompt: str, text: str, dialogue_format: Any, model_id: str, providerllm: str) -> Any:
    """Call the LLM with the given prompt and dialogue format."""
    response = fw_client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
        ],
        model=model_id,
        max_tokens=FIREWORKS_MAX_TOKENS,
        temperature=FIREWORKS_TEMPERATURE,
        response_format={
            "type": "json_object",
            "schema": dialogue_format.model_json_schema(),
        },
    )
    return response


# Cache of extracted PDF text, keyed by the MD5 hash of the uploaded file contents.
pdf_cache = {}


def clear_pdf_cache():
    global pdf_cache
    pdf_cache.clear()


def get_link_text(url: str):
    """Fetch the main content of a URL via the magic-html-api and return it as markdown."""
    api_url = "https://magic-html-api.vercel.app/api/extract"
    params = {
        "url": url,
        "output_format": "markdown"
    }

    response = requests.get(api_url, params=params)
    result = response.json()

    if result["success"]:
        return result["content"]
    else:
        raise Exception("Failed to extract content from URL")


async def get_pdf_text(pdf_file: UploadFile):
    """Extract text from an uploaded PDF, caching the result by content hash."""
    text = ""
    try:
        contents = await pdf_file.read()
        file_hash = hashlib.md5(contents).hexdigest()

        if file_hash in pdf_cache:
            return pdf_cache[file_hash]

        pdf_file_obj = io.BytesIO(contents)
        pdf_reader = PdfReader(pdf_file_obj)

        text = "\n\n".join([page.extract_text() for page in pdf_reader.pages])

        # Store the extracted text so repeated uploads of the same file skip re-parsing.
        pdf_cache[file_hash] = text

        # Rewind the upload stream so it can be read again by other handlers.
        await pdf_file.seek(0)

        return text

    except Exception as e:
        return {"error": str(e)}


def get_prompt(pdfContent: str, text: str, tone: str, length: str, language: str):
    modified_system_prompt = ""
    new_text = pdfContent + text
    if pdfContent:
        modified_system_prompt += f"\n\n{QUESTION_MODIFIER} {new_text}"
    if tone:
        modified_system_prompt += f"\n\n{TONE_MODIFIER} {tone}."
    if length:
        modified_system_prompt += f"\n\n{LENGTH_MODIFIERS[length]}"
    if language:
        modified_system_prompt += f"\n\n{LANGUAGE_MODIFIER} {language}."

    return modified_system_prompt
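
# Example (sketch): get_prompt appends the source text plus the optional tone / length / language
# modifiers. "funny", "medium", and "English" are illustrative values; the valid length keys are
# whatever LENGTH_MODIFIERS defines in prompts.py:
#
#   prompt = get_prompt("...pdf text...", "", "funny", "medium", "English")
#   for piece in call_llm_stream(SYSTEM_PROMPT, prompt, ShortDialogue, model_id, "fireworks", isJSON=False):
#       ...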