hotdeem commited on
Commit
d16f2df
·
verified ·
1 Parent(s): 633c6cf

Upload 12 files

Browse files
Files changed (12) hide show
  1. Dockerfile +21 -0
  2. README.md +10 -10
  3. api/__init__.py +0 -0
  4. api/main.py +7 -0
  5. api/routes/chat.py +141 -0
  6. constants.py +84 -0
  7. fishaudio.py +33 -0
  8. main.py +36 -0
  9. prompts.py +87 -0
  10. requirements.txt +21 -0
  11. schema.py +34 -0
  12. utils.py +308 -0
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.12-slim

# ffmpeg is needed by pydub for audio processing; git allows VCS pip installs.
RUN apt-get update && apt-get install -y git ffmpeg && \
apt-get clean && rm -rf /var/lib/apt/lists/*
# Run as a non-root user with uid 1000 (Hugging Face Spaces convention).
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
PATH=/home/user/.local/bin:$PATH

# Set the working directory
WORKDIR $HOME/app

# Copy requirements first so the dependency layer is cached across code edits.
COPY --chown=user requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY --chown=user . .

EXPOSE 7860

# Run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,10 @@
1
- ---
2
- title: Mp3
3
- emoji: 📉
4
- colorFrom: red
5
- colorTo: pink
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: mp3
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ python_version: 3.12
8
+ app_file: main.py
9
+ pinned: false
10
+ ---
api/__init__.py ADDED
File without changes
api/main.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
from fastapi import APIRouter

from api.routes import chat

# Top-level API router; main.py mounts this under the /api/v1 prefix,
# so chat endpoints resolve to /api/v1/chat/*.
api_router = APIRouter()
api_router.include_router(chat.router, prefix="/chat")
7
+
api/routes/chat.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from fastapi import APIRouter, BackgroundTasks, Form, HTTPException, UploadFile, File
3
+ from fastapi.responses import StreamingResponse, JSONResponse
4
+ import json
5
+ from typing import Dict, Optional
6
+ from constants import SPEEKERS
7
+ from utils import combine_audio, generate_dialogue, generate_podcast_info, generate_podcast_summary, get_link_text, get_pdf_text
8
+
9
+ router = APIRouter()
10
+
11
+ @router.post("/generate_transcript")
12
+ async def generate_transcript(
13
+ pdfFile: Optional[UploadFile] = File(None),
14
+ textInput: str = Form(...),
15
+ mode: str = Form(...),
16
+ url: Optional[str] = Form(None),
17
+ tone: str = Form(...),
18
+ duration: str = Form(...),
19
+ language: str = Form(...),
20
+
21
+ ):
22
+ pdfContent =""
23
+ if mode=='pdf':
24
+ pdfContent = await get_pdf_text(pdfFile)
25
+ else:
26
+ linkData = get_link_text(url)
27
+ pdfContent = linkData['text']
28
+ new_text = pdfContent
29
+ return StreamingResponse(generate_dialogue(new_text,textInput, tone, duration, language), media_type="application/json")
30
+
31
+
32
+ @router.get("/test")
33
+ def test():
34
+ return {"message": "Hello World"}
35
+
36
+
37
+ @router.get("/speekers")
38
+ def speeker():
39
+ return JSONResponse(content=SPEEKERS)
40
+
41
+ @router.get("/jina")
42
+ def jina():
43
+ result = get_link_text("https://ui.shadcn.com/docs/components/select")
44
+ return JSONResponse(content=result)
45
+
46
+
47
+ @router.post("/summarize")
48
+ async def get_summary(
49
+ textInput: str = Form(...),
50
+ tone: str = Form(...),
51
+ duration: str = Form(...),
52
+ language: str = Form(...),
53
+ mode: str = Form(...),
54
+ url: Optional[str] = Form(None),
55
+ pdfFile: Optional[UploadFile] = File(None)
56
+ ):
57
+ pdfContent =""
58
+ if mode=='pdf':
59
+ pdfContent = await get_pdf_text(pdfFile)
60
+ else:
61
+ linkData = get_link_text(url)
62
+ pdfContent = linkData['text']
63
+ new_text = pdfContent
64
+ return StreamingResponse(
65
+ generate_podcast_summary(
66
+ new_text,
67
+ textInput,
68
+ tone,
69
+ duration,
70
+ language,
71
+ ),
72
+ media_type="application/json"
73
+ )
74
+
75
+ @router.post("/pod_info")
76
+ async def get_pod_info(
77
+ textInput: str = Form(...),
78
+ tone: str = Form(...),
79
+ duration: str = Form(...),
80
+ language: str = Form(...),
81
+ mode: str = Form(...),
82
+ url: Optional[str] = Form(None),
83
+ pdfFile: Optional[UploadFile] = File(None)
84
+ ):
85
+ pdfContent =""
86
+ if mode=='pdf':
87
+ pdfContent = await get_pdf_text(pdfFile)
88
+ else:
89
+ linkData = get_link_text(url)
90
+ pdfContent = linkData['text']
91
+
92
+ new_text = pdfContent[:100]
93
+
94
+ return StreamingResponse(generate_podcast_info(new_text, textInput, tone, duration, language), media_type="application/json")
95
+
96
+
97
# In-memory registry of audio-generation jobs: task_id ->
# {"status": "processing"|"completed"|"failed", plus "audio_url"/"error"}.
# NOTE(review): process-local and never evicted — entries grow without
# bound and are lost on restart; confirm this is acceptable for the Space.
task_status: Dict[str, Dict] = {}


@router.post("/generate_audio")
async def audio(
    background_tasks: BackgroundTasks,
    text: str = Form(...),
    host_voice: str = Form(...),
    guest_voice: str = Form(...),
    language: str = Form(...) ,
    provider: str = Form(...)
):
    # Start TTS + mixing in the background and hand the client a polling
    # token; progress is reported via /audio_status/{task_id}.
    task_id = str(uuid.uuid4())
    task_status[task_id] = {"status": "processing"}

    background_tasks.add_task(combine_audio, task_status, task_id, text, language,provider , host_voice,guest_voice)

    return JSONResponse(content={"task_id": task_id, "status": "processing"})
115
+
116
+
117
+ @router.get("/audio_status/{task_id}")
118
+ async def get_audio_status(task_id: str):
119
+ if task_id not in task_status:
120
+ raise HTTPException(status_code=404, detail="Task not found")
121
+
122
+ status = task_status[task_id]
123
+
124
+ if status["status"] == "completed":
125
+ return JSONResponse(content={
126
+ "status": "completed",
127
+ "audio_url": status["audio_url"]
128
+ })
129
+ elif status["status"] == "failed":
130
+ return JSONResponse(content={
131
+ "status": "failed",
132
+ "error": status["error"]
133
+ })
134
+ else:
135
+ return JSONResponse(content={
136
+ "status": "processing"
137
+ })
138
+
139
+
140
+
141
+
constants.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ constants.py
3
+ """
4
+
5
+ import os
6
+
7
+ from pathlib import Path
8
+
9
+ # Key constants
10
+ CHARACTER_LIMIT = 100_000
11
+
12
+ # Gradio-related constants
13
+ GRADIO_CLEAR_CACHE_OLDER_THAN = 1 * 2 * 60 * 60 # 2 hours
14
+
15
+ AUDIO_CACHE_DIR = os.path.join(os.path.dirname(__file__), 'tmp', 'cache')
16
+
17
+ # Error messages-related constants
18
+ ERROR_MESSAGE_NO_INPUT = "Please provide at least one PDF file or a URL."
19
+ ERROR_MESSAGE_NOT_PDF = "The provided file is not a PDF. Please upload only PDF files."
20
+ ERROR_MESSAGE_NOT_SUPPORTED_IN_MELO_TTS = "The selected language is not supported without advanced audio generation. Please enable advanced audio generation or choose a supported language."
21
+ ERROR_MESSAGE_READING_PDF = "Error reading the PDF file"
22
+ ERROR_MESSAGE_TOO_LONG = "The total content is too long. Please ensure the combined text from PDFs and URL is fewer than {CHARACTER_LIMIT} characters."
23
+
24
+ SPEECH_KEY = os.getenv('SPEECH_KEY')
25
+ SPEECH_REGION = "japaneast"
26
+
27
+ FISHAUDIO_KEY = os.getenv('FISHAUDIO_KEY')
28
+ JINA_KEY = os.getenv('JINA_KEY','jina_c1759c7f49e14ced990ac7776800dc44ShJNTXBCizzwjE7IMFYJ6LD960cG')
29
+
30
+ # Fireworks API-related constants
31
+ FIREWORKS_API_KEY = os.getenv('FIREWORKS_API_KEY')
32
+ FIREWORKS_BASE_URL = os.getenv('FIREWORKS_BASE_URL',"https://api.fireworks.ai/inference/v1")
33
+ FIREWORKS_MAX_TOKENS = 16_384
34
+ FIREWORKS_MODEL_ID = os.getenv('FIREWORKS_MODEL_ID',"accounts/fireworks/models/llama-v3p1-405b-instruct")
35
+ FIREWORKS_TEMPERATURE = 0.1
36
+ FIREWORKS_JSON_RETRY_ATTEMPTS = 3
37
+ # Suno related constants
38
+ SUNO_LANGUAGE_MAPPING = {
39
+ "English": "en",
40
+ "Chinese": "zh",
41
+ "French": "fr",
42
+ "German": "de",
43
+ "Hindi": "hi",
44
+ "Italian": "it",
45
+ "Japanese": "ja",
46
+ "Korean": "ko",
47
+ "Polish": "pl",
48
+ "Portuguese": "pt",
49
+ "Russian": "ru",
50
+ "Spanish": "es",
51
+ "Turkish": "tr",
52
+ }
53
+
54
+
55
+ FISHAUDIO_SPEEKER = [
56
+ { "id": "59cb5986671546eaa6ca8ae6f29f6d22", "name": "央视配音" },
57
+ { "id": "738d0cc1a3e9430a9de2b544a466a7fc", "name": "雷军" },
58
+ { "id": "54a5170264694bfc8e9ad98df7bd89c3", "name": "丁真" },
59
+ { "id": "7f92f8afb8ec43bf81429cc1c9199cb1", "name": "AD学姐" },
60
+ { "id": "0eb38bc974e1459facca38b359e13511", "name": "赛马娘" },
61
+ { "id": "e80ea225770f42f79d50aa98be3cedfc", "name": "孙笑川258" },
62
+ { "id": "e4642e5edccd4d9ab61a69e82d4f8a14", "name": "蔡徐坤" },
63
+ { "id": "f7561ff309bd4040a59f1e600f4f4338", "name": "黑手" },
64
+ { "id": "332941d1360c48949f1b4e0cabf912cd", "name": "丁真(锐刻五代版)" },
65
+ { "id": "1aacaeb1b840436391b835fd5513f4c4", "name": "芙宁娜" },
66
+ { "id": "3b55b3d84d2f453a98d8ca9bb24182d6", "name": "邓紫琪" },
67
+ { "id": "7af4d620be1c4c6686132f21940d51c5", "name": "东雪莲" },
68
+ { "id": "e1cfccf59a1c4492b5f51c7c62a8abd2", "name": "永雏塔菲" },
69
+ { "id": "665e031efe27435780ebfa56cc7e0e0d", "name": "月半猫" },
70
+ { "id": "aebaa2305aa2452fbdc8f41eec852a79", "name": "雷军" },
71
+ { "id": "7c66db6e457c4d53b1fe428a8c547953", "name": "郭德纲" },
72
+ { "id": "99503144194c45ed8fb998ceac181dcc", "name": "贝利亚" },
73
+ { "id": "4462fa28f3824bff808a94a6075570e5", "name": "雷军" },
74
+ { "id": "188c9b7c06654042be0e8a25781761e8", "name": "周杰伦" },
75
+ { "id": "6ce7ea8ada884bf3889fa7c7fb206691", "name": "御女茉莉" }
76
+ ]
77
+ SPEEKERS = {
78
+ "fishaudio":FISHAUDIO_SPEEKER,
79
+ "azure":[
80
+ {"id":"zh-CN-YunxiNeural","name":"云希"},
81
+ {"id":"zh-CN-YunzeNeural","name":"云哲"},
82
+ {"id":"zh-CN-YunxuanNeural","name":"晓萱"},
83
+ ]
84
+ }
fishaudio.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fish_audio_sdk import Session, TTSRequest, ReferenceAudio
2
+ from pydub import AudioSegment
3
+ import io
4
+
5
+ from constants import FISHAUDIO_KEY,FISHAUDIO_SPEEKER
6
+
7
+
8
+
9
+ import random
10
+
11
def get_adapter_speeker_id(speaker_name):
    """Map a speaker label to a Fish Audio voice id.

    The host label ("主持人") always gets the first catalog entry; every
    other speaker is assigned a random voice from the catalog.
    """
    if speaker_name == "主持人":
        chosen = FISHAUDIO_SPEEKER[0]
    else:
        chosen = random.choice(FISHAUDIO_SPEEKER)
    return chosen["id"]
16
+
17
def fishaudio_tts(text, reference_id=None) -> AudioSegment:
    """
    Convert the given text to speech and return it as an AudioSegment.

    :param text: the text to synthesize
    :param reference_id: optional Fish Audio voice/model id to use
    :return: the synthesized speech decoded from mp3 into a pydub AudioSegment
    """
    session = Session(FISHAUDIO_KEY)
    audio_buffer = io.BytesIO()
    # The SDK streams the synthesized audio; collect all chunks in memory.
    for chunk in session.tts(TTSRequest(
        reference_id=reference_id,
        text=text
    )):
        audio_buffer.write(chunk)
    audio_buffer.seek(0)  # rewind the buffer before decoding
    return AudioSegment.from_file(audio_buffer, format="mp3")
main.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import os

from fastapi.responses import JSONResponse
from constants import AUDIO_CACHE_DIR
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from api.main import api_router

app = FastAPI()

# Serve generated audio files from the cache directory at /audio.
os.makedirs(AUDIO_CACHE_DIR, exist_ok=True)
app.mount("/audio", StaticFiles(directory=AUDIO_CACHE_DIR), name="audio")

# Add CORS middleware — fully open: any origin, method, and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(api_router, prefix="/api/v1")

@app.middleware("http")
async def add_process_time_header(request, call_next):
    # Abort any request running longer than 2400 s (40 minutes).
    # NOTE(review): the original comment said "4 minutes", which did not
    # match the configured value — confirm the intended timeout.
    try:
        response = await asyncio.wait_for(call_next(request), timeout=2400)
        return response
    except asyncio.TimeoutError:
        return JSONResponse(
            status_code=504,
            content={"detail": "Request processing time exceeded the limit."}
        )
prompts.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ prompts.py
3
+ """
4
+
5
+ SYSTEM_PROMPT = """
6
+ 你是一位世界级的播客制作人,任务是将提供的输入文本转化为引人入胜且内容丰富的播客脚本。输入内容可能是非结构化或杂乱的,来源于PDF或网页。你的目标是提取最有趣、最有洞察力的内容,形成一场引人入胜的播客讨论。
7
+
8
+ 操作步骤:
9
+
10
+ 1. 分析输入:
11
+ 仔细检查文本,识别出关键主题、要点,以及能推动播客对话的有趣事实或轶事。忽略无关的信息或格式问题。
12
+ 2. 编写对话:
13
+ 发展主持人与嘉宾(作者或该主题的专家)之间自然的对话流程,包含:
14
+ • 来自头脑风暴的最佳创意
15
+ • 对复杂话题的清晰解释
16
+ • 引人入胜的、活泼的语气以吸引听众
17
+ • 信息与娱乐的平衡
18
+ 对话规则:
19
+ • 主持人始终发起对话并采访嘉宾
20
+ • 包含主持人引导讨论的深思熟虑的问题
21
+ • 融入自然的口语模式,包括偶尔的语气词(如“嗯”,“好吧”,“你知道”)
22
+ • 允许主持人和嘉宾之间的自然打断和互动
23
+ • 嘉宾的回答必须基于输入文本,避免不支持的说法
24
+ • 保持PG级别的对话,适合所有观众
25
+ • 避免嘉宾的营销或自我推销内容
26
+ • 主持人结束对话
27
+ 3. 总结关键见解:
28
+ 在对话的结尾,自然地融入关键点总结。这应像是随意对话,而不是正式的回顾,强化主要的收获,然后结束。
29
+ 4. 保持真实性:
30
+ 在整个脚本中,努力保持对话的真实性,包含:
31
+ • 主持人表达出真实的好奇或惊讶时刻
32
+ • 嘉宾在表达复杂想法时可能短暂地有些卡顿
33
+ • 适当时加入轻松的时刻或幽默
34
+ • 简短的个人轶事或与主题相关的例子(以输入文本为基础)
35
+ 5. 考虑节奏与结构:
36
+ 确保对话有自然的起伏:
37
+ • 以强有力的引子吸引听众的注意力
38
+ • 随着对话进行,逐渐增加复杂性
39
+ • 包含短暂的“喘息”时刻,让听众消化复杂信息
40
+ • 以有力的方式收尾,或许以发人深省的问题或对听众的号召结束
41
+
42
+ 重要规则:每句对话不应超过100个字符(例如,可以在5-8秒内完成)。
43
+
44
+ 示例格式:
45
+ **Host**: 欢迎来到节目!今天我们讨论的是[话题]。我们的嘉宾是[嘉宾姓名].
46
+ **[Guest Name]**: 谢谢邀请,Jane。我很高兴分享我对[话题]的见解.
47
+
48
+ 记住,在整个对话中保持这种格式。
49
+ """
50
+
51
+ QUESTION_MODIFIER = "请回答这个问题:"
52
+
53
+ TONE_MODIFIER = "语气: 播客的语气应该是"
54
+
55
+ LANGUAGE_MODIFIER = "输出的语言<重要>:播客的语言应该是"
56
+
57
+ LENGTH_MODIFIERS = {
58
+ "short": "保持播客的简短, 大约 1-2 分钟.",
59
+ "medium": "中等长度, 大约 3-5 分钟.",
60
+ }
61
+
62
+
63
+ SUMMARY_INFO_PROMPT = """
64
+ 根据以下输入内容,生成一个播客梗概,使用 markdown 格式,遵循以下具体指南:
65
+
66
+ • 提供播客内容的概述(200-300字)。
67
+ • 突出3个关键点或收获。
68
+
69
+ """
70
+ PODCAST_INFO_PROMPT = """
71
+ 根据以下输入内容,生成一个吸引人的标题和一个富有创意的主持人名字。请遵循以下具体指南:
72
+
73
+ 1. 标题:
74
+ • 创建一个引人入胜且简洁的标题,准确反映播客内容。
75
+ 2. 主持人名字:
76
+ • 为播客主持人创造一个有创意且易记的名字。
77
+
78
+ 请以以下JSON格式提供输出:
79
+
80
+ {
81
+ "title": "An engaging and relevant podcast title",
82
+ "host_name": "A creative name for the host"
83
+ }
84
+
85
+ 确保你的回复是一个有效的 JSON 对象,且不包含其他内容。
86
+
87
+ """
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ annotated-types==0.7.0
2
+ anyio==4.6.0
3
+ click==8.1.7
4
+ fastapi==0.115.0
5
+ h11==0.14.0
6
+ idna==3.10
7
+ pydantic==2.7.0
8
+ pydantic_core==2.18.1
9
+ sniffio==1.3.1
10
+ starlette==0.38.6
11
+ typing_extensions==4.12.2
12
+ uvicorn==0.31.1
13
+ openai==1.50.2
14
+ pydub==0.25.1
15
+ loguru==0.7.2
16
+ suno-bark @ git+https://github.com/suno-ai/bark.git@f4f32d4cd480dfec1c245d258174bc9bde3c2148
17
+ numpy==2.1.1
18
+ python-multipart==0.0.12
19
+ PyPDF2==3.0.1
20
+ azure-cognitiveservices-speech==1.41.1
21
+ fish_audio_sdk
schema.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ schema.py
3
+ """
4
+
5
+ from typing import Literal, List
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+ class Summary(BaseModel):
10
+ """Summary."""
11
+
12
+ summary: str
13
+
14
+ class PodcastInfo(BaseModel):
15
+ """Summary."""
16
+
17
+ title: str
18
+ host_name: str
19
+
20
+
21
+ class DialogueItem(BaseModel):
22
+ """A single dialogue item."""
23
+
24
+ speaker: Literal["Host (Jane)", "Guest"]
25
+ text: str
26
+
27
+
28
+ class ShortDialogue(BaseModel):
29
+ """The dialogue between the host and guest."""
30
+
31
+ name_of_guest: str
32
+ dialogue: List[DialogueItem] = Field(
33
+ ..., description="A list of dialogue items, typically between 11 to 17 items"
34
+ )
utils.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import glob
3
+ import io
4
+ import os
5
+ import re
6
+ import time
7
+ import hashlib
8
+ from typing import Any, Dict, Generator
9
+ import uuid
10
+ from openai import OpenAI
11
+ import requests
12
+ from fishaudio import fishaudio_tts
13
+ from prompts import LANGUAGE_MODIFIER, LENGTH_MODIFIERS, PODCAST_INFO_PROMPT, QUESTION_MODIFIER, SUMMARY_INFO_PROMPT, SYSTEM_PROMPT, TONE_MODIFIER
14
+ import json
15
+ from pydub import AudioSegment
16
+ from fastapi import UploadFile
17
+ from PyPDF2 import PdfReader
18
+ from schema import PodcastInfo, ShortDialogue, Summary
19
+ from constants import (
20
+ AUDIO_CACHE_DIR,
21
+ FIREWORKS_API_KEY,
22
+ FIREWORKS_BASE_URL,
23
+ FIREWORKS_MODEL_ID,
24
+ FIREWORKS_MAX_TOKENS,
25
+ FIREWORKS_TEMPERATURE,
26
+ GRADIO_CLEAR_CACHE_OLDER_THAN,
27
+ JINA_KEY,
28
+ SPEECH_KEY,
29
+ SPEECH_REGION,
30
+ )
31
+ import azure.cognitiveservices.speech as speechsdk
32
+
33
+ fw_client = OpenAI(base_url=FIREWORKS_BASE_URL, api_key=FIREWORKS_API_KEY)
34
+
35
+
36
+
37
def generate_dialogue(pdfFile, textInput, tone, duration, language) -> Generator[str, None, None]:
    """Stream NDJSON events for a generated podcast dialogue.

    Yields {"type": "chunk"} events while the LLM streams, then a closing
    {"type": "final"} event with the full transcript. A single
    {"type": "error"} event is emitted when the prompt is rejected.
    """
    prompt = get_prompt(pdfFile, textInput, tone, duration, language)
    if prompt == False:
        yield json.dumps({"type": "error", "content": "Prompt is too long"}) + "\n"
        return

    transcript = ""
    for piece in call_llm_stream(SYSTEM_PROMPT, prompt, ShortDialogue, isJSON=False):
        yield json.dumps({"type": "chunk", "content": piece}) + "\n"
        transcript += piece

    yield json.dumps({"type": "final", "content": transcript})
53
+
54
async def process_line(line, voice, provider):
    """Synthesize one dialogue line with the selected TTS provider."""
    content = line['content']
    if provider != 'fishaudio':
        # Anything other than fishaudio is handled by Azure Speech.
        return await generate_podcast_audio_by_azure(content, voice)
    return await generate_podcast_audio(content, voice)
58
+
59
async def generate_podcast_audio_by_azure(text: str, voice: str) -> "AudioSegment | None":
    """Synthesize `text` via Azure Speech with the given neural voice name.

    The blocking SDK calls run in worker threads so the event loop is not
    stalled. Returns the audio as a pydub AudioSegment on success, or None
    when the service reports a failure; unexpected exceptions re-raise.
    (Annotation corrected from `str`, which did not match the returns.)
    """
    try:
        speech_config = speechsdk.SpeechConfig(subscription=SPEECH_KEY, region=SPEECH_REGION)
        speech_config.speech_synthesis_voice_name = voice

        # audio_config=None keeps the result in memory instead of playing it.
        synthesizer = speechsdk.SpeechSynthesizer(speech_config=speech_config, audio_config=None)
        future =await asyncio.to_thread(synthesizer.speak_text_async, text)

        # future.get() blocks until synthesis finishes, so thread it too.
        result = await asyncio.to_thread(future.get)

        print("Speech synthesis completed")

        if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
            print("Audio synthesized successfully")
            audio_data = result.audio_data
            audio_segment = AudioSegment.from_wav(io.BytesIO(audio_data))
            return audio_segment
        else:
            print(f"Speech synthesis failed: {result.reason}")
            if hasattr(result, 'cancellation_details'):
                print(f"Cancellation details: {result.cancellation_details.reason}")
                print(f"Cancellation error details: {result.cancellation_details.error_details}")
            return None

    except Exception as e:
        print(f"Error in generate_podcast_audio: {e}")
        raise
86
+
87
async def generate_podcast_audio(text: str, voice: str) -> str:
    """Dispatch TTS to the Fish Audio backend (the only one wired here)."""
    segment = await generate_podcast_audio_by_fish(text, voice)
    return segment
89
+
90
async def generate_podcast_audio_by_fish(text: str, voice: str) -> str:
    """Synthesize `text` with Fish Audio; `voice` is a reference model id.

    NOTE(review): fishaudio_tts is synchronous, so this call blocks the
    event loop for the duration of the request — confirm acceptable.
    """
    try:
        return fishaudio_tts(text=text,reference_id=voice)
    except Exception as e:
        print(f"Error in generate_podcast_audio: {e}")
        raise
96
async def process_lines_with_limit(lines, provider, host_voice, guest_voice, max_concurrency):
    """Synthesize all dialogue lines concurrently, capped by a semaphore."""
    gate = asyncio.Semaphore(max_concurrency)

    async def synth_one(line):
        async with gate:
            is_host = line['speaker'] == '主持人' or line['speaker'] == 'Host'
            voice = host_voice if is_host else guest_voice
            return await process_line(line, voice, provider)

    # gather preserves input order, matching the original behavior.
    return await asyncio.gather(*(synth_one(line) for line in lines))
107
async def combine_audio(task_status: Dict[str, Dict], task_id: str, text: str, language: str , provider:str,host_voice: str , guest_voice:str) -> Generator[str, None, None]:
    """Background task: turn a '**Speaker**: line' transcript into one mp3.

    Parses the transcript, synthesizes every line with the chosen provider,
    concatenates the segments, writes the result into AUDIO_CACHE_DIR, and
    records the outcome (URL or error) in `task_status[task_id]`.
    """
    try:
        # Lines look like "**主持人**: text" (full- or half-width colon).
        dialogue_regex = r'\*\*([\s\S]*?)\*\*[::]\s*([\s\S]*?)(?=\*\*|$)'
        matches = re.findall(dialogue_regex, text, re.DOTALL)

        lines = [
            {
                "speaker": match[0],
                "content": match[1].strip(),
            }
            for match in matches
        ]

        print("Starting audio generation")
        # Azure tolerates more parallel requests than Fish Audio.
        max_parallel = 10 if provider == 'azure' else 5
        audio_segments = await process_lines_with_limit(lines, provider, host_voice, guest_voice, max_parallel)
        print("Audio generation completed")

        # Concatenate the segments off the event loop.
        combined_audio = await asyncio.to_thread(sum, audio_segments)

        print("Audio combined")

        # Write the file only once, at the end.
        unique_filename = f"{uuid.uuid4()}.mp3"

        os.makedirs(AUDIO_CACHE_DIR, exist_ok=True)
        file_path = os.path.join(AUDIO_CACHE_DIR, unique_filename)

        # Export the audio file without blocking the loop.
        await asyncio.to_thread(combined_audio.export, file_path, format="mp3")

        audio_url = f"/audio/{unique_filename}"
        task_status[task_id] = {"status": "completed", "audio_url": audio_url}

        # Evict cached mp3s older than the retention window.
        # BUGFIX: the glob pattern was f"{AUDIO_CACHE_DIR}*.mp3" (no path
        # separator), so stale files were never matched.
        for file in glob.glob(os.path.join(AUDIO_CACHE_DIR, "*.mp3")):
            if (
                os.path.isfile(file)
                and time.time() - os.path.getmtime(file) > GRADIO_CLEAR_CACHE_OLDER_THAN
            ):
                # BUGFIX: was `os.remove, file` — a tuple expression that
                # never deleted anything.
                os.remove(file)

        clear_pdf_cache()
        return audio_url

    except Exception as e:
        # Mark the task failed so pollers see the error.
        task_status[task_id] = {"status": "failed", "error": str(e)}
158
+
159
+
160
def generate_podcast_summary(pdf_content: str, text: str, tone: str, length: str, language: str) -> Generator[str, None, None]:
    """Stream NDJSON events for a podcast summary of the given content.

    Yields {"type": "chunk"} events while the LLM streams, then a closing
    {"type": "final"} event with the complete summary text.
    """
    modified_system_prompt = get_prompt(pdf_content, text, '', '', '')
    if (modified_system_prompt == False):
        yield json.dumps({
            "type": "error",
            "content": "Prompt is too long"
        }) + "\n"
        return
    stream = call_llm_stream(SUMMARY_INFO_PROMPT, modified_system_prompt, Summary, False)
    full_response = ""
    for chunk in stream:
        # Yield each chunk as a JSON line.
        yield json.dumps({"type": "chunk", "content": chunk}) + "\n"
        # BUGFIX: chunks were never accumulated, so the "final" event
        # always carried an empty string.
        full_response += chunk

    yield json.dumps({"type": "final", "content": full_response})
175
+
176
def generate_podcast_info(pdfContent: str, text: str, tone: str, length: str, language: str) -> Generator[str, None, None]:
    """Stream a single NDJSON event with the generated title and host name.

    Collects the LLM's streamed JSON response, parses it, and yields either
    a {"type": "podcast_info"} event or a {"type": "error"} event.
    """
    prompt = get_prompt(pdfContent, text, '', '', '')
    if prompt == False:
        yield json.dumps({"type": "error", "content": "Prompt is too long"}) + "\n"
        return

    # Drain the stream into one string before parsing.
    raw = "".join(call_llm_stream(PODCAST_INFO_PROMPT, prompt, PodcastInfo))
    try:
        info = json.loads(raw)
        yield json.dumps({"type": "podcast_info", "content": info}) + "\n"
    except Exception as e:
        yield json.dumps({
            "type": "error",
            "content": f"An unexpected error occurred: {str(e)}"
        }) + "\n"
200
+
201
def call_llm_stream(system_prompt: str, text: str, dialogue_format: Any, isJSON: bool = True) -> Generator[str, None, None]:
    """Call the LLM and yield response text incrementally.

    When `isJSON` is true the request asks Fireworks for a JSON object
    conforming to `dialogue_format`'s schema; otherwise free-form text.
    """
    request_params = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
        ],
        "model": FIREWORKS_MODEL_ID,
        "max_tokens": FIREWORKS_MAX_TOKENS,
        "temperature": FIREWORKS_TEMPERATURE,
        "stream": True,  # enable streaming output
    }

    # Constrain the response to the caller-supplied schema if requested.
    if isJSON:
        request_params["response_format"] = {
            "type": "json_object",
            "schema": dialogue_format.model_json_schema(),
        }

    for chunk in fw_client.chat.completions.create(**request_params):
        delta = chunk.choices[0].delta.content
        if delta is not None:
            yield delta
235
+
236
def call_llm(system_prompt: str, text: str, dialogue_format: Any) -> Any:
    """Non-streaming LLM call returning the raw completion response.

    The response is constrained to `dialogue_format`'s JSON schema.
    """
    schema = dialogue_format.model_json_schema()
    return fw_client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
        ],
        model=FIREWORKS_MODEL_ID,
        max_tokens=FIREWORKS_MAX_TOKENS,
        temperature=FIREWORKS_TEMPERATURE,
        response_format={"type": "json_object", "schema": schema},
    )
252
+
253
# Cache of extracted PDF text keyed by the file content's MD5 digest.
pdf_cache = {}


def clear_pdf_cache():
    """Empty the PDF text cache (invoked after each audio job completes)."""
    global pdf_cache
    pdf_cache.clear()
257
+
258
def get_link_text(url: str):
    """Scrape the given URL's text content via the jina.ai reader API."""
    url = f"https://r.jina.ai/{url}"
    headers = {}
    headers['Authorization'] = 'Bearer ' + JINA_KEY
    headers['Accept'] = 'application/json'
    # Ask the reader to return plain text rather than markdown.
    headers['X-Return-Format'] = 'text'
    response = requests.get(url, headers=headers)
    # The payload of interest is under "data" (callers read data['text']).
    return response.json()['data']
267
+
268
async def get_pdf_text(pdf_file: UploadFile):
    """Extract the text of an uploaded PDF, with content-hash caching.

    Returns the extracted text, or {"error": ...} when reading fails
    (callers treat this as best-effort).
    """
    try:
        # Read the uploaded file's bytes and fingerprint them.
        contents = await pdf_file.read()
        file_hash = hashlib.md5(contents).hexdigest()

        if file_hash in pdf_cache:
            return pdf_cache[file_hash]

        # Parse the PDF from an in-memory buffer.
        pdf_file_obj = io.BytesIO(contents)
        pdf_reader = PdfReader(pdf_file_obj)

        # extract_text() may return None for pages without a text layer;
        # guard so the join does not raise.
        text = "\n\n".join([(page.extract_text() or "") for page in pdf_reader.pages])

        # BUGFIX: the cache was checked but never populated, so every
        # request re-parsed the same PDF.
        pdf_cache[file_hash] = text

        # Rewind so the file can be read again downstream if needed.
        await pdf_file.seek(0)

        return text

    except Exception as e:
        return {"error": str(e)}
295
+
296
def get_prompt(pdfContent: str, text: str, tone: str, length: str, language: str):
    """Assemble the user prompt from the content plus optional modifiers.

    Each provided piece is appended as its own "\n\n"-prefixed section.
    """
    sections = []
    combined = pdfContent + text
    if pdfContent:
        sections.append(f"{QUESTION_MODIFIER} {combined}")
    if tone:
        sections.append(f"{TONE_MODIFIER} {tone}.")
    if length:
        sections.append(LENGTH_MODIFIERS[length])
    if language:
        sections.append(f"{LANGUAGE_MODIFIER} {language}.")
    return "".join(f"\n\n{section}" for section in sections)