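# NOTE: the page banner "Spaces: Running" appeared here; presumably Hugging Face page-scrape residue, not part of the program.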
import asyncio
import json
import logging
import os
import re
import tempfile
import traceback
from typing import Any, Optional

import gradio as gr
import soundfile as sf
import torch
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import HandoffTermination, TextMentionTermination
from autogen_agentchat.messages import TextMessage, HandoffMessage, StructuredMessage
from autogen_agentchat.teams import Swarm
from autogen_agentchat.ui import Console
from autogen_ext.models.anthropic import AnthropicChatCompletionClient
from autogen_ext.models.ollama import OllamaChatCompletionClient
from autogen_ext.models.openai import OpenAIChatCompletionClient
from pydantic import BaseModel
from pydub import AudioSegment
from serpapi import GoogleSearch
from TTS.api import TTS
# Set up logging: every record (DEBUG and above) goes to both a log file in the
# current working directory and the console stream.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("lecture_generation.log"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
# Set up environment
OUTPUT_DIR = os.path.join(os.getcwd(), "outputs")  # Fallback for local dev
os.makedirs(OUTPUT_DIR, exist_ok=True)
logger.info("Using output directory: %s", OUTPUT_DIR)
# NOTE(review): presumably pre-accepts the Coqui TTS terms-of-service prompt so
# that loading the model does not block on interactive input — confirm.
os.environ["COQUI_TOS_AGREED"] = "1"
# Initialize TTS model at the top
try:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
    logger.info("TTS model initialized on %s", device)
except Exception as e:
    # Best-effort: keep the module importable and leave tts as None;
    # generate_xtts_audio reports the missing model instead of crashing here.
    logger.exception("Failed to initialize TTS model: %s", e)
    tts = None
# Define Pydantic model for slide data
class Slide(BaseModel):
    """One lecture slide, described by a title and its content text."""

    # Slide title.
    title: str
    # Slide content text.
    content: str
class SlidesOutput(BaseModel):
    """Container model holding a list of Slide objects under the "slides" field."""

    slides: list[Slide]
# Define search_web tool using SerpApi
def search_web(query: str, serpapi_key: str) -> Optional[str]:
    """Run a Google search through SerpApi and format the top results as text.

    Args:
        query: Search query string.
        serpapi_key: SerpApi API key sent with the request.

    Returns:
        Up to five organic results, each rendered as ``Title:``/``Snippet:``/
        ``Link:`` lines and separated by a blank line, or None when SerpApi
        reports an error, returns no organic results, or anything raises.
    """
    # NOTE(review): on_generate registers this function directly as an agent
    # tool (tools=[search_web]); the model would then have to supply
    # serpapi_key itself — verify the key is bound before registration.
    try:
        params = {
            "q": query,
            "engine": "google",
            "api_key": serpapi_key,
            "num": 5,
        }
        results = GoogleSearch(params).get_dict()

        if "error" in results:
            logger.error("SerpApi error: %s", results["error"])
            return None

        organic_results = results.get("organic_results")
        if not organic_results:
            logger.info("No search results found for query: %s", query)
            return None

        formatted_results = []
        for item in organic_results[:5]:
            title = item.get("title", "No title")
            snippet = item.get("snippet", "No snippet")
            link = item.get("link", "No link")
            formatted_results.append(f"Title: {title}\nSnippet: {snippet}\nLink: {link}\n")

        formatted_output = "\n".join(formatted_results)
        logger.info("Successfully retrieved search results for query: %s", query)
        return formatted_output
    except Exception as e:
        # Deliberately broad: a failed search must not abort lecture generation.
        logger.exception("Unexpected error during search: %s", e)
        return None
# Define helper function for progress HTML
def html_with_progress(label, progress):
    """Build the HTML snippet showing a progress bar with a status label.

    Args:
        label: Status text rendered inside the ``<h2>`` under the bar. It is
            inserted as-is (no HTML escaping).
        progress: Completion percentage; values outside 0-100 are clamped so
            the bar width is always a valid CSS percentage.

    Returns:
        The HTML markup as a string.
    """
    percent = max(0, min(100, progress))
    return (
        "\n"
        '<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;">\n'
        '<div style="width: 70%; background-color: #FFFFFF; border-radius: 80px; overflow: hidden; margin-bottom: 20px;">\n'
        f'<div style="width: {percent}%; height: 15px; background-color: #4CAF50; border-radius: 80px;"></div>\n'
        "</div>\n"
        f'<h2 style="font-style: italic; color: #555;">{label}</h2>\n'
        "</div>\n"
    )
# Function to get model client based on selected service
def get_model_client(service, api_key):
    """Create the chat-completion client for the selected service label.

    Args:
        service: One of "OpenAI-gpt-4o-2024-08-06",
            "Anthropic-claude-3-sonnet-20240229", "Google-gemini-1.5-flash"
            or "Ollama-llama3.2".
        api_key: API key passed to the client; not used for the Ollama client.

    Returns:
        The constructed model client.

    Raises:
        ValueError: If ``service`` is not one of the supported labels.
    """
    if service == "OpenAI-gpt-4o-2024-08-06":
        return OpenAIChatCompletionClient(model="gpt-4o-2024-08-06", api_key=api_key)
    if service == "Anthropic-claude-3-sonnet-20240229":
        return AnthropicChatCompletionClient(model="claude-3-sonnet-20240229", api_key=api_key)
    if service == "Google-gemini-1.5-flash":
        # NOTE(review): Gemini goes through the OpenAI client class; this
        # presumably relies on that client routing gemini-* model names to
        # Google's OpenAI-compatible endpoint — confirm for the installed version.
        return OpenAIChatCompletionClient(model="gemini-1.5-flash", api_key=api_key)
    if service == "Ollama-llama3.2":
        return OllamaChatCompletionClient(model="llama3.2")
    raise ValueError(f"Invalid service: {service!r}")
# Helper function to clean script text
def clean_script_text(script: Any) -> Optional[str]:
    """Strip slide markup from a narration script so it can be spoken.

    Removes ``**Slide N: ...**`` headers, bracketed asides, and
    ``Title:``/``Content:`` lines, hyphenates two known run-together words,
    and collapses all whitespace runs to single spaces.

    Args:
        script: Raw script text.

    Returns:
        The cleaned text, or None when the input is empty, not a string, or
        shorter than 10 characters after cleaning.
    """
    if not script or not isinstance(script, str):
        logger.error("Invalid script input: %s", script)
        return None

    script = re.sub(r"\*\*Slide \d+:.*?\*\*", "", script)
    script = re.sub(r"\[.*?\]", "", script)
    script = re.sub(r"Title:.*?\n|Content:.*?\n", "", script)
    script = script.replace("humanlike", "human-like").replace("problemsolving", "problem-solving")
    script = re.sub(r"\s+", " ", script).strip()

    if len(script) < 10:
        logger.error("Cleaned script too short (%d characters): %s", len(script), script)
        return None

    logger.info("Cleaned script: %s", script)
    return script
# Helper function to validate and convert speaker audio
async def validate_and_convert_speaker_audio(speaker_audio: Optional[str]) -> Optional[str]:
    """Return a path to a mono WAV usable as the TTS reference voice.

    Falls back to ``feynman.mp3`` next to this file when ``speaker_audio`` is
    empty or does not exist. MP3 input is converted to a mono 22050 Hz WAV and
    stereo WAV input is down-mixed to mono; both conversions write a new
    temporary file into OUTPUT_DIR.

    Args:
        speaker_audio: Path to an ``.mp3`` or ``.wav`` file, or a falsy value
            to use the default voice.

    Returns:
        Path of the validated WAV file, or None when no usable audio exists,
        the extension is unsupported, the sample rate is outside
        16000-48000 Hz, the audio has fewer than 16000 frames, or any step
        raises.
    """
    if not speaker_audio or not os.path.exists(speaker_audio):
        logger.warning("Speaker audio file does not exist: %s. Using default voice.", speaker_audio)
        default_voice = os.path.join(os.path.dirname(__file__), "feynman.mp3")
        if not os.path.exists(default_voice):
            logger.error("Default voice not found. Cannot proceed with TTS.")
            return None
        speaker_audio = default_voice

    # Temp files are created with delete=False so the result outlives this
    # call; track them so failed validations do not leave orphans behind.
    created_paths = []
    validated_wav = None
    try:
        ext = os.path.splitext(speaker_audio)[1].lower()
        if ext == ".mp3":
            logger.info("Converting MP3 to WAV: %s", speaker_audio)
            audio = AudioSegment.from_mp3(speaker_audio)
            audio = audio.set_channels(1).set_frame_rate(22050)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False, dir=OUTPUT_DIR) as temp_file:
                created_paths.append(temp_file.name)
                audio.export(temp_file.name, format="wav")
                speaker_wav = temp_file.name
        elif ext == ".wav":
            speaker_wav = speaker_audio
        else:
            logger.error("Unsupported audio format: %s", ext)
            return None

        data, samplerate = sf.read(speaker_wav)
        if samplerate < 16000 or samplerate > 48000:
            logger.error("Invalid sample rate for %s: %d Hz", speaker_wav, samplerate)
            return None
        if len(data) < 16000:
            logger.error("Speaker audio too short: %d frames", len(data))
            return None
        if data.ndim == 2:
            logger.info("Converting stereo WAV to mono: %s", speaker_wav)
            data = data.mean(axis=1)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False, dir=OUTPUT_DIR) as temp_file:
                created_paths.append(temp_file.name)
                sf.write(temp_file.name, data, samplerate)
                speaker_wav = temp_file.name

        logger.info("Validated speaker audio: %s", speaker_wav)
        validated_wav = speaker_wav
        return validated_wav
    except Exception as e:
        logger.exception("Failed to validate or convert speaker audio %s: %s", speaker_audio, e)
        return None
    finally:
        # Remove every temp file created here except the one being returned.
        for path in created_paths:
            if path == validated_wav:
                continue
            try:
                os.remove(path)
            except OSError:
                logger.warning("Could not remove temporary audio file: %s", path)
# Helper function to generate audio using Coqui TTS API
def generate_xtts_audio(tts, text, speaker_wav, output_path) -> bool:
    """Synthesize ``text`` to ``output_path`` with the given TTS model.

    Args:
        tts: TTS model object exposing ``tts_to_file``, or None when the model
            failed to initialize.
        text: Text to speak; synthesized with ``language="en"``.
        speaker_wav: Path to the reference voice audio passed to the model.
        output_path: Destination path for the generated audio file.

    Returns:
        True when the audio file was written, False when no model is
        available or synthesis raised.
    """
    # Explicit None check: do not rely on the truthiness of a model object.
    if tts is None:
        logger.error("TTS model not initialized")
        return False
    try:
        tts.tts_to_file(text=text, speaker_wav=speaker_wav, language="en", file_path=output_path)
    except Exception as e:
        logger.exception("Failed to generate audio for %s: %s", output_path, e)
        return False
    logger.info("Generated audio for %s", output_path)
    return True
# Helper function to extract JSON from messages
def extract_json_from_message(message):
    """Pull a JSON payload out of an agent message.

    Handling per message type:
      * TextMessage: search the string content for a fenced json code block,
        then for inline array/object patterns, then for the first position
        from which a JSON object or array can be decoded.
      * StructuredMessage: dump a pydantic model to a dict and return its
        "slides" entry if present (else the whole dict); non-model content is
        returned unchanged.
      * HandoffMessage: inspect each context message; string content is
        searched with the fenced-block and inline patterns only, dict content
        returns its "slides" entry if present (else the dict).

    Args:
        message: The message object to inspect.

    Returns:
        The parsed JSON value (or structured content), or None when nothing
        could be extracted or the message type is unsupported.
    """
    if isinstance(message, TextMessage):
        content = message.content
        logger.debug("Extracting JSON from TextMessage: %s", content)
        if not isinstance(content, str):
            logger.warning("TextMessage content is not a string: %s", content)
            return None
        parsed = _extract_json_from_text(content, scan_fallback=True)
        if parsed is None:
            logger.warning("No JSON found in TextMessage content")
        return parsed

    if isinstance(message, StructuredMessage):
        content = message.content
        logger.debug("Extracting JSON from StructuredMessage: %s", content)
        try:
            if isinstance(content, BaseModel):
                # Prefer pydantic v2's model_dump(); .dict() is the older
                # spelling and is kept as a fallback.
                dump = getattr(content, "model_dump", None) or content.dict
                content_dict = dump()
                return content_dict.get("slides", content_dict)
            return content
        except Exception as e:
            logger.error("Failed to extract JSON from StructuredMessage: %s, Content: %s", e, content)
            return None

    if isinstance(message, HandoffMessage):
        logger.debug("Extracting JSON from HandoffMessage context")
        for ctx_msg in message.context:
            if not hasattr(ctx_msg, "content"):
                continue
            content = ctx_msg.content
            logger.debug("HandoffMessage context content: %s", content)
            if isinstance(content, str):
                # No scan fallback here: context strings are only matched
                # against the explicit patterns, as before.
                parsed = _extract_json_from_text(content, scan_fallback=False)
                if parsed is not None:
                    return parsed
            elif isinstance(content, dict):
                return content.get("slides", content)
        logger.warning("No JSON found in HandoffMessage context")
        return None

    logger.warning("Unsupported message type for JSON extraction: %s", type(message))
    return None


def _extract_json_from_text(text, *, scan_fallback):
    """Return the first JSON value found in ``text``, or None.

    Tries, in order: a fenced json code block, a non-greedy "array of objects"
    pattern, a non-greedy "object" pattern, and (only when ``scan_fallback``
    is true) decoding from each ``{`` or ``[`` character in turn.
    """
    fenced = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL)
    if fenced:
        json_str = fenced.group(1).strip()
        logger.debug("Found JSON in code block: %s", json_str)
        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            logger.error("Failed to parse JSON from code block: %s", e)

    json_patterns = (
        r"\[\s*\{.*?\}\s*\]",
        r"\{\s*\".*?\"\s*:.*?\}",
    )
    for pattern in json_patterns:
        match = re.search(pattern, text, re.DOTALL)
        if not match:
            continue
        json_str = match.group(0).strip()
        logger.debug("Found JSON with pattern %s: %s", pattern, json_str)
        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            logger.error("Failed to parse JSON with pattern %s: %s", pattern, e)

    if not scan_fallback:
        return None

    # raw_decode parses the single JSON value starting at an offset, which
    # yields the same value as trying every substring from that offset but
    # without the nested-loop cost.
    decoder = json.JSONDecoder()
    for start, char in enumerate(text):
        if char not in "{[":
            continue
        try:
            parsed, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            continue
        logger.info("Found JSON in substring: %s", text[start:end])
        return parsed
    return None
# Function to generate Markdown slides
def generate_markdown_slides(slides, title, speaker="Prof. AI Feynman", date="April 26th, 2025"):
    """Render slide dicts as a list of Markdown strings, one per slide.

    The first slide is a title slide (``#`` heading, content, bold speaker,
    italic date). Every later slide gets a ``##### Slide N, <slide title>``
    header line, its content, and a footer line built from the lecture title,
    speaker and date.

    Args:
        slides: Iterable of mappings with "title" and "content" keys.
        title: Lecture title used in the footer of non-title slides.
        speaker: Speaker name shown on every slide.
        date: Date text shown on every slide.

    Returns:
        List of Markdown strings in slide order, or None if rendering fails
        (for example when a slide lacks a required key).
    """
    try:
        markdown_slides = []
        for slide_number, slide in enumerate(slides, start=1):
            content = slide["content"]
            # First slide has no header/footer, others have header and footer
            if slide_number == 1:
                lines = [
                    f"# {slide['title']}",
                    f"{content}",
                    f"**{speaker}**",
                    f"*{date}*",
                ]
            else:
                lines = [
                    f"##### Slide {slide_number}, {slide['title']}",
                    f"{content}",
                    f", {title} {speaker}, {date}",
                ]
            markdown_slides.append("\n".join(lines).strip())
        logger.info("Generated Markdown slides for: %s: %s", title, markdown_slides)
        return markdown_slides
    except Exception as e:
        # Deliberately broad: the caller treats None as "rendering failed".
        logger.exception("Failed to generate Markdown slides: %s", e)
        return None
# Async function to update audio preview
async def update_audio_preview(audio_file):
    """Echo the selected audio file back for preview.

    Args:
        audio_file: The audio file value to preview; may be falsy.

    Returns:
        ``audio_file`` unchanged when it is truthy, otherwise None.
    """
    if not audio_file:
        return None
    logger.info("Updating audio preview for file: %s", audio_file)
    return audio_file
# Async function to generate lecture materials and audio | |
async def on_generate(api_service, api_key, serpapi_key, title, lecture_content_description, lecture_type, speaker_audio, num_slides): | |
model_client = get_model_client(api_service, api_key) | |
total_slides = num_slides # Use exactly the number of slides from input | |
research_agent = AssistantAgent( | |
name="research_agent", | |
model_client=model_client, | |
handoffs=["slide_agent"], | |
system_message="You are a Research Agent. Use the search_web tool to gather information on the topic and keywords from the initial message. Summarize the findings concisely in a single message, then use the handoff_to_slide_agent tool to pass the task to the Slide Agent. Do not produce any other output.", | |
tools=[search_web] | |
) | |
slide_agent = AssistantAgent( | |
name="slide_agent", | |
model_client=model_client, | |
handoffs=["script_agent"], | |
system_message=f""" | |
You are a Slide Agent. Using the research from the conversation history and the specified number of slides ({total_slides}), generate exactly {total_slides} content slides. Output ONLY a JSON array wrapped in ```json ... ``` in a TextMessage, where each slide is an object with 'title' and 'content' keys. Do not include any explanatory text, comments, or other messages. Ensure the JSON is valid and contains exactly {total_slides} slides before proceeding. After outputting the JSON, use the handoff_to_script_agent tool to pass the task to the Script Agent. | |
Example output for 2 slides: | |
```json | |
[ | |
{{"title": "Slide 1", "content": "Content for slide 1"}}, | |
{{"title": "Slide 2", "content": "Content for slide 2"}} | |
] | |
```""", | |
output_content_type=None, | |
reflect_on_tool_use=False | |
) | |
script_agent = AssistantAgent( | |
name="script_agent", | |
model_client=model_client, | |
handoffs=["feynman_agent"], | |
system_message=f""" | |
You are a Script Agent. Access the JSON array of {total_slides} slides from the conversation history. Generate a narration script (1-2 sentences) for each of the {total_slides} slides, summarizing its content in a clear, academically inclined tone as a professor would deliver it. Avoid using non-verbal fillers such as "um," "you know," or "like." Output ONLY a JSON array wrapped in ```json ... ``` with exactly {total_slides} strings, one script per slide, in the same order. Ensure the JSON is valid and complete. After outputting, use the handoff_to_feynman_agent tool. If scripts cannot be generated, retry once. | |
Example for 3 slides: | |
```json | |
[ | |
"Hello everyone, welcome to Agents 101. I am Jaward, your primary instructor for this course.", | |
"Today, we will cover the syllabus for this semester, providing a gentle introduction to AI agents.", | |
"Let us define what an AI agent is: it refers to a system or program capable of autonomously performing tasks on behalf of a user or another system." | |
] | |
```""", | |
output_content_type=None, | |
reflect_on_tool_use=False | |
) | |
feynman_agent = AssistantAgent( | |
name="feynman_agent", | |
model_client=model_client, | |
handoffs=[], | |
system_message=f""" | |
You are Agent Feynman. Review the slides and scripts from the conversation history to ensure coherence, completeness, and that exactly {total_slides} slides and {total_slides} scripts are received. Output a confirmation message summarizing the number of slides and scripts received. If slides or scripts are missing, invalid, or do not match the expected count ({total_slides}), report the issue clearly. Use 'TERMINATE' to signal completion. | |
Example: 'Received {total_slides} slides and {total_slides} scripts. Lecture is coherent. TERMINATE' | |
""") | |
swarm = Swarm( | |
participants=[research_agent, slide_agent, script_agent, feynman_agent], | |
termination_condition=HandoffTermination(target="user") | TextMentionTermination("TERMINATE") | |
) | |
progress = 0 | |
label = "Research: in progress..." | |
yield ( | |
html_with_progress(label, progress), | |
[] | |
) | |
await asyncio.sleep(0.1) | |
initial_message = f""" | |
Lecture Title: {title} | |
Lecture Content Description: {lecture_content_description} | |
Audience: {lecture_type} | |
Number of Slides: {total_slides} | |
Please start by researching the topic, or proceed without research if search is unavailable. | |
""" | |
logger.info("Starting lecture generation for title: %s with %d slides", title, total_slides) | |
slides = None | |
scripts = None | |
error_html = """ | |
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;"> | |
<h2 style="color: #d9534f;">Failed to generate lecture materials</h2> | |
<p style="margin-top: 20px;">Please try again with different parameters or a different model.</p> | |
</div> | |
""" | |
try: | |
logger.info("Research Agent starting...") | |
if serpapi_key: | |
task_result = await Console(swarm.run_stream(task=initial_message)) | |
else: | |
logger.warning("No SerpApi key provided, bypassing research phase") | |
task_result = await Console(swarm.run_stream(task=f"{initial_message}\nNo search available, proceed with slide generation.")) | |
logger.info("Swarm execution completed") | |
slide_retry_count = 0 | |
script_retry_count = 0 | |
max_retries = 2 | |
for message in task_result.messages: | |
source = getattr(message, 'source', getattr(message, 'sender', None)) | |
logger.debug("Processing message from %s, type: %s", source, type(message)) | |
if isinstance(message, HandoffMessage): | |
logger.info("Handoff from %s to %s", source, message.target) | |
if source == "research_agent" and message.target == "slide_agent": | |
progress = 25 | |
label = "Slides: generating..." | |
yield ( | |
html_with_progress(label, progress), | |
[] | |
) | |
await asyncio.sleep(0.1) | |
elif source == "slide_agent" and message.target == "script_agent": | |
if slides is None: | |
logger.warning("Slide Agent handoff without slides JSON") | |
extracted_json = extract_json_from_message(message) | |
if extracted_json: | |
slides = extracted_json | |
logger.info("Extracted slides JSON from HandoffMessage context: %s", slides) | |
if slides is None or len(slides) != total_slides: | |
if slide_retry_count < max_retries: | |
slide_retry_count += 1 | |
logger.info("Retrying slide generation (attempt %d/%d)", slide_retry_count, max_retries) | |
retry_message = TextMessage( | |
content=f"Please generate exactly {total_slides} slides as per your instructions.", | |
source="user", | |
recipient="slide_agent" | |
) | |
task_result.messages.append(retry_message) | |
continue | |
progress = 50 | |
label = "Scripts: generating..." | |
yield ( | |
html_with_progress(label, progress), | |
[] | |
) | |
await asyncio.sleep(0.1) | |
elif source == "script_agent" and message.target == "feynman_agent": | |
if scripts is None: | |
logger.warning("Script Agent handoff without scripts JSON") | |
extracted_json = extract_json_from_message(message) | |
if extracted_json: | |
scripts = extracted_json | |
logger.info("Extracted scripts JSON from HandoffMessage context: %s", scripts) | |
progress = 75 | |
label = "Review: in progress..." | |
yield ( | |
html_with_progress(label, progress), | |
[] | |
) | |
await asyncio.sleep(0.1) | |
elif source == "research_agent" and isinstance(message, TextMessage) and "handoff_to_slide_agent" in message.content: | |
logger.info("Research Agent completed research") | |
progress = 25 | |
label = "Slides: generating..." | |
yield ( | |
html_with_progress(label, progress), | |
[] | |
) | |
await asyncio.sleep(0.1) | |
elif source == "slide_agent" and isinstance(message, (TextMessage, StructuredMessage)): | |
logger.debug("Slide Agent message received") | |
extracted_json = extract_json_from_message(message) | |
if extracted_json: | |
slides = extracted_json | |
logger.info("Slide Agent generated %d slides: %s", len(slides), slides) | |
if len(slides) != total_slides: | |
if slide_retry_count < max_retries: | |
slide_retry_count += 1 | |
logger.info("Retrying slide generation (attempt %d/%d)", slide_retry_count, max_retries) | |
retry_message = TextMessage( | |
content=f"Please generate exactly {total_slides} slides as per your instructions.", | |
source="user", | |
recipient="slide_agent" | |
) | |
task_result.messages.append(retry_message) | |
continue | |
for i, slide in enumerate(slides): | |
content_file = os.path.join(OUTPUT_DIR, f"slide_{i+1}_content.txt") | |
try: | |
with open(content_file, "w", encoding="utf-8") as f: | |
f.write(slide["content"]) | |
logger.info("Saved slide content to %s", content_file) | |
except Exception as e: | |
logger.error("Error saving slide content to %s: %s", content_file, str(e)) | |
progress = 50 | |
label = "Scripts: generating..." | |
yield ( | |
html_with_progress(label, progress), | |
[] | |
) | |
await asyncio.sleep(0.1) | |
else: | |
logger.warning("No JSON extracted from slide_agent message") | |
if slide_retry_count < max_retries: | |
slide_retry_count += 1 | |
logger.info("Retrying slide generation (attempt %d/%d)", slide_retry_count, max_retries) | |
retry_message = TextMessage( | |
content=f"Please generate exactly {total_slides} slides as per your instructions.", | |
source="user", | |
recipient="slide_agent" | |
) | |
task_result.messages.append(retry_message) | |
continue | |
elif source == "script_agent" and isinstance(message, (TextMessage, StructuredMessage)): | |
logger.debug("Script Agent message received") | |
extracted_json = extract_json_from_message(message) | |
if extracted_json: | |
scripts = extracted_json | |
logger.info("Script Agent generated scripts for %d slides: %s", len(scripts), scripts) | |
for i, script in enumerate(scripts): | |
script_file = os.path.join(OUTPUT_DIR, f"slide_{i+1}_script.txt") | |
try: | |
with open(script_file, "w", encoding="utf-8") as f: | |
f.write(script) | |
logger.info("Saved script to %s", script_file) | |
except Exception as e: | |
logger.error("Error saving script to %s: %s", script_file, str(e)) | |
progress = 75 | |
label = "Scripts generated and saved. Reviewing..." | |
yield ( | |
html_with_progress(label, progress), | |
[] | |
) | |
await asyncio.sleep(0.1) | |
else: | |
logger.warning("No JSON extracted from script_agent message") | |
if script_retry_count < max_retries: | |
script_retry_count += 1 | |
logger.info("Retrying script generation (attempt %d/%d)", script_retry_count, max_retries) | |
retry_message = TextMessage( | |
content=f"Please generate exactly {total_slides} scripts for the {total_slides} slides as per your instructions.", | |
source="user", | |
recipient="script_agent" | |
) | |
task_result.messages.append(retry_message) | |
continue | |
elif source == "feynman_agent" and isinstance(message, TextMessage) and "TERMINATE" in message.content: | |
logger.info("Feynman Agent completed lecture review: %s", message.content) | |
progress = 90 | |
label = "Lecture materials ready. Generating audio..." | |
yield ( | |
html_with_progress(label, progress), | |
[] | |
) | |
await asyncio.sleep(0.1) | |
logger.info("Slides state: %s", "Generated" if slides else "None") | |
logger.info("Scripts state: %s", "Generated" if scripts else "None") | |
if not slides or not scripts: | |
error_message = f"Failed to generate {'slides and scripts' if not slides and not scripts else 'slides' if not slides else 'scripts'}" | |
error_message += f". Received {len(slides) if slides else 0} slides and {len(scripts) if scripts else 0} scripts." | |
logger.error("%s", error_message) | |
logger.debug("Dumping all messages for debugging:") | |
for msg in task_result.messages: | |
source = getattr(msg, 'source', getattr(msg, 'sender', None)) | |
logger.debug("Message from %s, type: %s, content: %s", source, type(msg), msg.to_text() if hasattr(msg, 'to_text') else str(msg)) | |
yield ( | |
f""" | |
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;"> | |
<h2 style="color: #d9534f;">{error_message}</h2> | |
<p style="margin-top: 20px;">Please try again with a different model or adjust your inputs.</p> | |
</div> | |
""", | |
[] | |
) | |
return | |
if len(slides) != total_slides: | |
logger.error("Expected %d slides, but received %d", total_slides, len(slides)) | |
yield ( | |
f""" | |
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;"> | |
<h2 style="color: #d9534f;">Incorrect number of slides</h2> | |
<p style="margin-top: 20px;">Expected {total_slides} slides, but generated {len(slides)}. Please try again.</p> | |
</div> | |
""", | |
[] | |
) | |
return | |
if not isinstance(scripts, list) or not all(isinstance(s, str) for s in scripts): | |
logger.error("Scripts are not a list of strings: %s", scripts) | |
yield ( | |
f""" | |
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;"> | |
<h2 style="color: #d9534f;">Invalid script format</h2> | |
<p style="margin-top: 20px;">Scripts must be a list of strings. Please try again.</p> | |
</div> | |
""", | |
[] | |
) | |
return | |
if len(scripts) != total_slides: | |
logger.error("Mismatch between number of slides (%d) and scripts (%d)", len(slides), len(scripts)) | |
yield ( | |
f""" | |
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;"> | |
<h2 style="color: #d9534f;">Mismatch in slides and scripts</h2> | |
<p style="margin-top: 20px;">Generated {len(slides)} slides but {len(scripts)} scripts. Please try again.</p> | |
</div> | |
""", | |
[] | |
) | |
return | |
markdown_slides = generate_markdown_slides(slides, title) | |
if not markdown_slides: | |
logger.error("Failed to generate Markdown slides") | |
yield ( | |
f""" | |
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;"> | |
<h2 style="color: #d9534f;">Failed to generate slides</h2> | |
<p style="margin-top: 20px;">Please try again.</p> | |
</div> | |
""", | |
[] | |
) | |
return | |
audio_files = [] | |
audio_urls = [] | |
validated_speaker_wav = await validate_and_convert_speaker_audio(speaker_audio) | |
if not validated_speaker_wav: | |
logger.error("Invalid speaker audio after conversion, skipping TTS") | |
yield ( | |
f""" | |
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;"> | |
<h2 style="color: #d9534f;">Invalid speaker audio</h2> | |
<p style="margin-top: 20px;">Please upload a valid MP3 or WAV audio file and try again.</p> | |
</div> | |
""", | |
[] | |
) | |
return | |
for i, script in enumerate(scripts): | |
cleaned_script = clean_script_text(script) | |
audio_file = os.path.join(OUTPUT_DIR, f"slide_{i+1}.mp3") | |
script_file = os.path.join(OUTPUT_DIR, f"slide_{i+1}_script.txt") | |
try: | |
with open(script_file, "w", encoding="utf-8") as f: | |
f.write(cleaned_script or "") | |
logger.info("Saved script to %s: %s", script_file, cleaned_script) | |
except Exception as e: | |
logger.error("Error saving script to %s: %s", script_file, str(e)) | |
if not cleaned_script: | |
logger.error("Skipping audio for slide %d due to empty or invalid script", i + 1) | |
audio_files.append(None) | |
audio_urls.append(None) | |
progress = 90 + ((i + 1) / len(scripts)) * 10 | |
label = f"Generated audio for slide {i + 1}/{len(scripts)}..." | |
yield ( | |
html_with_progress(label, progress), | |
[] | |
) | |
await asyncio.sleep(0.1) | |
continue | |
max_audio_retries = 2 | |
for attempt in range(max_audio_retries + 1): | |
try: | |
current_text = cleaned_script | |
if attempt > 0: | |
sentences = re.split(r"[.!?]+", cleaned_script) | |
sentences = [s.strip() for s in sentences if s.strip()][:2] | |
current_text = ". ".join(sentences) + "." | |
logger.info("Retry %d for slide %d with simplified text: %s", attempt, i + 1, current_text) | |
success = generate_xtts_audio(tts, current_text, validated_speaker_wav, audio_file) | |
if not success: | |
raise RuntimeError("TTS generation failed") | |
logger.info("Generated audio for slide %d: %s", i + 1, audio_file) | |
audio_files.append(audio_file) | |
# Use Gradio's file serving URL | |
audio_urls.append(f"/gradio_api/file={audio_file}") | |
progress = 90 + ((i + 1) / len(scripts)) * 10 | |
label = f"Generated audio for slide {i + 1}/{len(scripts)}..." | |
yield ( | |
html_with_progress(label, progress), | |
[] | |
) | |
await asyncio.sleep(0.1) | |
break | |
except Exception as e: | |
logger.error("Error generating audio for slide %d (attempt %d): %s\n%s", i + 1, attempt, str(e), traceback.format_exc()) | |
if attempt == max_audio_retries: | |
logger.error("Max retries reached for slide %d, skipping", i + 1) | |
audio_files.append(None) | |
audio_urls.append(None) | |
progress = 90 + ((i + 1) / len(scripts)) * 10 | |
label = f"Generated audio for slide {i + 1}/{len(scripts)}..." | |
yield ( | |
html_with_progress(label, progress), | |
[] | |
) | |
await asyncio.sleep(0.1) | |
break | |
# Collect .txt files for download | |
txt_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith('.txt')] | |
txt_files.sort() # Sort for consistent display | |
txt_file_paths = [os.path.join(OUTPUT_DIR, f) for f in txt_files] | |
# Generate audio timeline with playable audio elements | |
audio_timeline = "" | |
for i, audio_url in enumerate(audio_urls): | |
if audio_url: | |
audio_timeline += f'<audio id="audio-{i+1}" controls src="{audio_url}" style="display: inline-block; margin: 0 10px; width: 200px;"></audio>' | |
else: | |
audio_timeline += f'<span id="audio-{i+1}" style="display: inline-block; margin: 0 10px;">slide_{i+1}.mp3 (not generated)</span>' | |
slides_info = json.dumps({"slides": markdown_slides, "audioFiles": audio_urls}) | |
html_output = f""" | |
<script src="https://cdn.jsdelivr.net/npm/[email protected]/marked.min.js"></script> | |
<div id="lecture-container" style="height: 700px; border: 1px solid #ddd; border-radius: 8px; display: flex; flex-direction: column; justify-content: space-between;"> | |
<div id="slide-content" style="flex: 1; overflow: auto; padding: 20px; text-align: center; background-color: #fff; color: #333;"> | |
<!-- Slides will be rendered here --> | |
</div> | |
<div style="padding: 20px; text-align: center;"> | |
<div style="display: flex; justify-content: center; margin-bottom: 10px;"> | |
{audio_timeline} | |
</div> | |
<div style="display: flex; justify-content: center; margin-bottom: 10px;"> | |
<button id="prev-btn" style="border-radius: 50%; width: 40px; height: 40px; margin: 0 5px; font-size: 1.2em; cursor: pointer;">⏮</button> | |
<button id="play-btn" style="border-radius: 50%; width: 40px; height: 40px; margin: 0 5px; font-size: 1.2em; cursor: pointer;">⏯</button> | |
<button id="next-btn" style="border-radius: 50%; width: 40px; height: 40px; margin: 0 5px; font-size: 1.2em; cursor: pointer;">⏭</button> | |
<button style="border-radius: 50%; width: 40px; height: 40px; margin: 0 5px; font-size: 1.2em; cursor: pointer;">☐</button> | |
</div> | |
</div> | |
</div> | |
<script> | |
const lectureData = {slides_info}; | |
let currentSlide = 0; | |
const totalSlides = lectureData.slides.length; | |
let audioElements = []; | |
// Populate audio elements | |
for (let i = 0; i < totalSlides; i++) {{ | |
const audio = document.getElementById(`audio-${{i+1}}`); | |
audioElements.push(audio); | |
}} | |
function renderSlide() {{ | |
const slideContent = document.getElementById('slide-content'); | |
if (lectureData.slides[currentSlide]) {{ | |
const markdownText = lectureData.slides[currentSlide]; | |
const htmlContent = marked.parse(markdownText); | |
slideContent.innerHTML = htmlContent; | |
console.log("Rendering slide:", markdownText); | |
console.log("Rendered HTML:", htmlContent); | |
}} else {{ | |
slideContent.innerHTML = '<h2>No slide content available</h2>'; | |
console.log("No slide content for index:", currentSlide); | |
}} | |
}} | |
function updateSlide() {{ | |
renderSlide(); | |
audioElements.forEach(audio => {{ | |
if (audio && audio.pause) {{ | |
audio.pause(); | |
audio.currentTime = 0; | |
}} | |
}}); | |
}} | |
function prevSlide() {{ | |
if (currentSlide > 0) {{ | |
currentSlide--; | |
updateSlide(); | |
}} | |
}} | |
function nextSlide() {{ | |
if (currentSlide < totalSlides - 1) {{ | |
currentSlide++; | |
updateSlide(); | |
}} | |
}} | |
function playAll() {{ | |
let index = 0; | |
function playNext() {{ | |
if (index >= totalSlides) return; | |
const audio = audioElements[index]; | |
if (audio && audio.play) {{ | |
audio.play().then(() => {{ | |
audio.addEventListener('ended', () => {{ | |
index++; | |
playNext(); | |
}}, {{ once: true }}); | |
}}).catch(e => {{ | |
console.error('Audio play failed:', e); | |
index++; | |
playNext(); | |
}}); | |
}} else {{ | |
index++; | |
playNext(); | |
}} | |
}} | |
playNext(); | |
}} | |
// Attach event listeners | |
document.getElementById('prev-btn').addEventListener('click', prevSlide); | |
document.getElementById('play-btn').addEventListener('click', playAll); | |
document.getElementById('next-btn').addEventListener('click', nextSlide); | |
// Initialize first slide | |
renderSlide(); | |
</script> | |
""" | |
logger.info("Lecture generation completed successfully") | |
yield ( | |
html_output, | |
txt_file_paths | |
) | |
except Exception as e: | |
logger.error("Error during lecture generation: %s\n%s", str(e), traceback.format_exc()) | |
yield ( | |
f""" | |
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;"> | |
<h2 style="color: #d9534f;">Error during lecture generation</h2> | |
<p style="margin-top: 10px; font-size: 16px;">{str(e)}</p> | |
<p style="margin-top: 20px;">Please try again or adjust your inputs.</p> | |
</div> | |
""", | |
[] | |
) | |
return | |
# Gradio interface
# Two-column page: lecture settings on the left, slide viewer and downloads
# on the right.
# NOTE(review): nesting is inferred from the order of the context managers;
# confirm the "Generate Lecture" button is meant to sit inside the input group.

# Placeholder panel shown in the slide area until a lecture has been generated.
default_slide_html = """
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;">
<h2 style="font-style: italic; color: #555;">Waiting for lecture content...</h2>
<p style="margin-top: 10px; font-size: 16px;">Please Generate lecture content via the form on the left first before lecture begins</p>
</div>
"""

# Options offered by the "Audience" dropdown.
_AUDIENCE_CHOICES = ["Conference", "University", "High school"]

# Options offered by the "Model" dropdown, written as "<provider>-<model id>".
_MODEL_CHOICES = [
    "OpenAI-gpt-4o-2024-08-06",
    "Anthropic-claude-3-sonnet-20240229",
    "Google-gemini-1.5-flash",
    "Ollama-llama3.2",
]

with gr.Blocks(title="Agent Feynman") as demo:
    gr.Markdown("# <center>Learn Anything With Professor AI Feynman</center>")

    with gr.Row():
        # Left column: everything the user fills in before generating.
        with gr.Column(scale=1):
            with gr.Group():
                title = gr.Textbox(
                    label="Lecture Title",
                    placeholder="e.g. Introduction to AI",
                )
                lecture_content_description = gr.Textbox(
                    label="Lecture Content Description",
                    placeholder="e.g. Focus on recent advancements",
                )
                lecture_type = gr.Dropdown(
                    choices=_AUDIENCE_CHOICES,
                    label="Audience",
                    value="University",
                )
                api_service = gr.Dropdown(
                    choices=_MODEL_CHOICES,
                    label="Model",
                    value="Google-gemini-1.5-flash",
                )
                api_key = gr.Textbox(
                    label="Model Provider API Key",
                    type="password",
                    placeholder="Not required for Ollama",
                )
                serpapi_key = gr.Textbox(
                    label="SerpApi Key",
                    type="password",
                    placeholder="Enter your SerpApi key (optional)",
                )
                num_slides = gr.Slider(
                    minimum=1,
                    maximum=20,
                    step=1,
                    label="Number of Slides",
                    value=3,
                )
                speaker_audio = gr.Audio(
                    label="Speaker sample audio (MP3 or WAV)",
                    type="filepath",
                    elem_id="speaker-audio",
                )
                generate_btn = gr.Button("Generate Lecture")

        # Right column: rendered slides plus the downloadable artifacts.
        with gr.Column(scale=2):
            slide_display = gr.HTML(label="Lecture Slides", value=default_slide_html)
            file_output = gr.File(label="Download Generated Files")

    # Event wiring has to be registered while the Blocks context is still open.
    # The audio component is both the input and the output of its own handler.
    speaker_audio.change(
        fn=update_audio_preview,
        inputs=speaker_audio,
        outputs=speaker_audio,
    )

    # Gradio hands the input values to on_generate positionally, so this list
    # must stay in the same order as that function's parameters.
    generate_btn.click(
        fn=on_generate,
        inputs=[
            api_service,
            api_key,
            serpapi_key,
            title,
            lecture_content_description,
            lecture_type,
            speaker_audio,
            num_slides,
        ],
        outputs=[slide_display, file_output],
    )

if __name__ == "__main__":
    # OUTPUT_DIR is passed as an allowed path (presumably so generated files
    # can be served to the browser -- confirm); uploads are capped at "5mb".
    demo.launch(allowed_paths=[OUTPUT_DIR], max_file_size="5mb")