# app.py ("Update app.py", commit 89a534d, 43.9 kB) as published by Jaward on Hugging Face Spaces.
# Note: For Huggingface Spaces, ensure the Dockerfile includes:
# RUN mkdir -p /tmp/cache/
# RUN chmod a+rwx -R /tmp/cache/
# ENV TRANSFORMERS_CACHE=/tmp/cache/
import os
import json
import re
import gradio as gr
import asyncio
import logging
import torch
import random
from serpapi import GoogleSearch
from pydantic import BaseModel
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import HandoffTermination, TextMentionTermination
from autogen_agentchat.teams import Swarm
from autogen_agentchat.ui import Console
from autogen_agentchat.messages import TextMessage, HandoffMessage, StructuredMessage
from autogen_ext.models.anthropic import AnthropicChatCompletionClient
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.ollama import OllamaChatCompletionClient
from markdown_pdf import MarkdownPdf, Section
import traceback
import soundfile as sf
import tempfile
from pydub import AudioSegment
from TTS.api import TTS
# Configure root logging once at import time: DEBUG and above goes both to
# lecture_generation.log and to a default StreamHandler.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
_log_handlers = [
    logging.FileHandler("lecture_generation.log"),
    logging.StreamHandler(),
]
logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT, handlers=_log_handlers)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
# Hugging Face Spaces environment. Generated artifacts (PDF, audio, scripts)
# are written under OUTPUT_DIR, which the original author marks as persistent
# storage on Spaces.
OUTPUT_DIR = "/data/outputs"

# NOTE(review): presumably pre-accepts the Coqui model license prompt so the
# XTTS download does not block on stdin — confirm against the TTS package.
os.environ["COQUI_TOS_AGREED"] = "1"

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Let Gradio serve files from OUTPUT_DIR directly.
gr.set_static_paths(paths=[OUTPUT_DIR])
# Define Pydantic model for slide data
class Slide(BaseModel):
    """Schema for one slide, mirroring the 'title'/'content' keys the slide agent is asked to emit."""

    # Slide heading.
    title: str
    # Slide body text.
    content: str
class SlidesOutput(BaseModel):
    """Container for a full deck of slides.

    extract_json_from_message unwraps the ``slides`` field when a structured
    message carries a model with that attribute.
    """

    # Slides in presentation order.
    slides: list[Slide]
# Define search_web tool using SerpApi
def search_web(query: str, serpapi_key: str) -> str:
    """Run a Google search through SerpApi and return the top hits as plain text.

    Args:
        query: Search phrase.
        serpapi_key: SerpApi API key used for the request.

    Returns:
        A text block with up to five results (title, snippet, link), or a
        human-readable error / "no results" message. This function never
        raises: any exception is logged and reported in the returned string.
    """
    try:
        response = GoogleSearch(
            {
                "q": query,
                "engine": "google",
                "api_key": serpapi_key,
                "num": 5,
            }
        ).get_dict()

        # SerpApi reports failures in-band under an "error" key.
        if "error" in response:
            logger.error("SerpApi error: %s", response["error"])
            return f"Error during search: {response['error']}"

        hits = response.get("organic_results")
        if not hits:
            logger.info("No search results found for query: %s", query)
            return f"No results found for query: {query}"

        entries = []
        for hit in hits[:5]:
            title = hit.get("title", "No title")
            snippet = hit.get("snippet", "No snippet")
            link = hit.get("link", "No link")
            entries.append(f"Title: {title}\nSnippet: {snippet}\nLink: {link}\n")
        formatted_output = "\n".join(entries)

        logger.info("Successfully retrieved search results for query: %s", query)
        return f"Search results for {query}:\n{formatted_output}"
    except Exception as e:
        logger.error("Unexpected error during search: %s", str(e))
        return f"Unexpected error during search: {str(e)}"
# Define helper function for progress HTML
def html_with_progress(label, progress):
    """Build the HTML card that shows a progress bar with a caption under it.

    Args:
        label: Caption text rendered in the ``<h2>`` below the bar.
        progress: Fill percentage; interpolated as-is into the bar's CSS width.

    Returns:
        The HTML fragment as a string (it begins and ends with a newline).
    """
    card_lines = [
        "",
        '<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;">',
        '<div style="width: 100%; background-color: #FFFFFF; border-radius: 10px; overflow: hidden; margin-bottom: 20px;">',
        f'<div style="width: {progress}%; height: 30px; background-color: #4CAF50; border-radius: 10px;"></div>',
        "</div>",
        f'<h2 style="font-style: italic; color: #555;">{label}</h2>',
        "</div>",
        "",
    ]
    return "\n".join(card_lines)
# Function to get model client based on selected service
def get_model_client(service, api_key):
    """Create the chat-completion client for the model chosen in the UI.

    Args:
        service: One of the labels offered by the "Model" dropdown.
        api_key: Provider API key; not passed to the Ollama client.

    Returns:
        A freshly constructed client for the selected service.

    Raises:
        ValueError: If ``service`` is not one of the known labels.
    """
    # Factories are lazy so only the selected client is ever constructed.
    factories = (
        (
            "OpenAI-gpt-4o-2024-08-06",
            lambda: OpenAIChatCompletionClient(model="gpt-4o-2024-08-06", api_key=api_key),
        ),
        (
            "Anthropic-claude-3-sonnet-20240229",
            lambda: AnthropicChatCompletionClient(model="claude-3-sonnet-20240229", api_key=api_key),
        ),
        (
            # Gemini is reached through the OpenAI client class.
            "Google-gemini-1.5-flash",
            lambda: OpenAIChatCompletionClient(model="gemini-1.5-flash", api_key=api_key),
        ),
        (
            "Ollama-llama3.2",
            lambda: OllamaChatCompletionClient(model="llama3.2"),
        ),
    )
    for known_service, build_client in factories:
        if service == known_service:
            return build_client()
    raise ValueError("Invalid service")
# Helper function to clean script text and make it natural
def clean_script_text(script):
    """Strip slide markup from a narration script and make it sound spoken.

    Steps, in order:
      1. Remove ``**Slide N: ...**`` headers, ``[bracketed]`` notes and
         ``Title:`` / ``Content:`` metadata lines.
      2. Turn leading ``-`` bullets into a spoken "So, " cue.
      3. Collapse all whitespace to single spaces.
      4. Randomly insert filler words ("um,", "you know,", "like,") before
         roughly 10% of the words, so the output is intentionally
         non-deterministic.

    Args:
        script: Raw narration text for one slide.

    Returns:
        The cleaned script, or ``None`` when the input is empty / not a string
        or the result is shorter than 10 characters.
    """
    if not script or not isinstance(script, str):
        logger.error("Invalid script input: %s", script)
        return None

    # Minimal cleaning to preserve natural language
    script = re.sub(r"\*\*Slide \d+:.*?\*\*", "", script)  # Remove slide headers
    script = re.sub(r"\[.*?\]", "", script)  # Remove bracketed content
    script = re.sub(r"Title:.*?\n|Content:.*?\n", "", script)  # Remove metadata
    script = script.replace("humanlike", "human-like").replace("problemsolving", "problem-solving")

    # Convert bullet points to spoken cues. This must run while line breaks
    # still exist: once whitespace is collapsed, "^" can only match the very
    # start of the text and every later bullet would be left as a bare "-".
    script = re.sub(r"^\s*-\s*", "So, ", script, flags=re.MULTILINE)
    script = re.sub(r"\s+", " ", script).strip()  # Normalize whitespace

    # Add non-verbal words randomly. Fillers carry no trailing space because
    # the join below already supplies the separator.
    non_verbal = ["um,", "you know,", "like,"]
    words = script.split()
    # Walk backwards so insertions do not shift the indices still to visit.
    for i in range(len(words) - 1, -1, -1):
        if random.random() < 0.1:  # 10% chance per word
            words.insert(i, random.choice(non_verbal))
    script = " ".join(words)

    # Basic validation
    if len(script) < 10:
        logger.error("Cleaned script too short (%d characters): %s", len(script), script)
        return None

    logger.info("Cleaned and naturalized script: %s", script)
    return script
# Helper function to validate and convert speaker audio (MP3 or WAV)
async def validate_and_convert_speaker_audio(speaker_audio):
    """Check a speaker reference clip and return a mono WAV path for it.

    MP3 input is converted to a mono 22050 Hz WAV in a temporary file; WAV
    input is used as-is unless it has two dimensions (multi-channel), in which
    case a mono copy is written to a temporary file. Temporary files are
    created with ``delete=False`` and are not removed here.

    Args:
        speaker_audio: Path to an ``.mp3`` or ``.wav`` file.

    Returns:
        Path to the validated WAV file, or ``None`` if the file is missing,
        has an unsupported extension, has a sample rate outside
        16000-48000 Hz, holds fewer than 16000 frames, or cannot be read.
    """
    if not os.path.exists(speaker_audio):
        logger.error("Speaker audio file does not exist: %s", speaker_audio)
        return None

    try:
        suffix = os.path.splitext(speaker_audio)[1].lower()
        if suffix not in (".mp3", ".wav"):
            logger.error("Unsupported audio format: %s", suffix)
            return None

        if suffix == ".wav":
            speaker_wav = speaker_audio
        else:
            logger.info("Converting MP3 to WAV: %s", speaker_audio)
            # Convert to mono, 22050 Hz
            mono_clip = AudioSegment.from_mp3(speaker_audio).set_channels(1).set_frame_rate(22050)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_tmp:
                mono_clip.export(wav_tmp.name, format="wav")
                speaker_wav = wav_tmp.name

        # Validate WAV file
        samples, rate = sf.read(speaker_wav)
        if not 16000 <= rate <= 48000:
            logger.error("Invalid sample rate for %s: %d Hz", speaker_wav, rate)
            return None
        if len(samples) < 16000:
            logger.error("Speaker audio too short: %d frames", len(samples))
            return None

        if samples.ndim == 2:
            # Average the channels down to a single track.
            logger.info("Converting stereo WAV to mono: %s", speaker_wav)
            samples = samples.mean(axis=1)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as mono_tmp:
                sf.write(mono_tmp.name, samples, rate)
                speaker_wav = mono_tmp.name

        logger.info("Validated speaker audio: %s", speaker_wav)
        return speaker_wav
    except Exception as e:
        logger.error("Failed to validate or convert speaker audio %s: %s", speaker_audio, str(e))
        return None
# Helper function to generate audio using Coqui TTS API
def generate_xtts_audio(tts, text, speaker_wav, output_path):
    """Synthesize ``text`` in the reference speaker's voice and write it to disk.

    Args:
        tts: Initialized Coqui TTS model (falsy means "not available").
        text: English text to speak.
        speaker_wav: Path to the reference speaker WAV.
        output_path: Destination audio file path.

    Returns:
        ``True`` when the file was generated, ``False`` when the model is
        missing or synthesis raised (the error is logged, not propagated).
    """
    if not tts:
        logger.error("TTS model not initialized")
        return False

    succeeded = False
    try:
        tts.tts_to_file(text=text, speaker_wav=speaker_wav, language="en", file_path=output_path)
        logger.info("Generated audio for %s", output_path)
        succeeded = True
    except Exception as e:
        logger.error("Failed to generate audio for %s: %s", output_path, str(e))
    return succeeded
# Helper function to extract JSON from messages
_FENCED_JSON_RE = re.compile(r"```json\s*(.*?)\s*```", re.DOTALL)
_OBJECT_ARRAY_RE = re.compile(r"\[\s*\{.*?\}\s*\]", re.DOTALL)


def _parse_json_text(content, origin, raw_origin):
    """Try, in order, three ways of pulling JSON out of free text.

    1. A fenced ```json ... ``` block.
    2. The first bare ``[ {...} ]`` array of objects.
    3. The whole string parsed as JSON (accepted only if a list or dict).

    ``origin`` / ``raw_origin`` are labels used purely in log messages.
    Returns the parsed value, or ``None`` when no stage succeeds.
    """
    fenced = _FENCED_JSON_RE.search(content)
    if fenced:
        try:
            parsed = json.loads(fenced.group(1))
            logger.info("Parsed JSON from %s: %s", origin, parsed)
            return parsed
        except json.JSONDecodeError as e:
            logger.error("Failed to parse JSON from %s: %s, Content: %s", origin, e, content)

    # Fallback: Try raw JSON array
    bare_array = _OBJECT_ARRAY_RE.search(content)
    if bare_array:
        try:
            parsed = json.loads(bare_array.group(0))
            logger.info("Parsed fallback JSON from %s: %s", origin, parsed)
            return parsed
        except json.JSONDecodeError as e:
            logger.error("Failed to parse fallback JSON from %s: %s, Content: %s", origin, e, content)

    # Fallback: Try any JSON-like structure
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        return None
    if isinstance(parsed, (list, dict)):
        logger.info("Parsed JSON from %s: %s", raw_origin, parsed)
        return parsed
    return None


def extract_json_from_message(message):
    """Extract slide/script JSON from an agent message.

    Supported message kinds:
      * ``TextMessage``: the text is searched with the fenced / bare-array /
        raw cascade implemented by ``_parse_json_text``.
      * ``StructuredMessage``: a Pydantic model is dumped to a dict and its
        ``slides`` entry (if any) is returned; other content is returned as-is.
      * ``HandoffMessage``: each context message is inspected in order; string
        content goes through the same cascade, dict content yields its
        ``slides`` entry (or the dict itself).

    Returns:
        The parsed JSON value (usually a list), or ``None`` if nothing usable
        was found or the message type is unsupported.
    """
    if isinstance(message, TextMessage):
        content = message.content
        logger.debug("Extracting JSON from TextMessage: %s", content)
        if not isinstance(content, str):
            logger.warning("TextMessage content is not a string: %s", content)
            return None
        parsed = _parse_json_text(content, "TextMessage", "raw content")
        if parsed is None:
            logger.warning("No JSON found in TextMessage content: %s", content)
        return parsed

    if isinstance(message, StructuredMessage):
        content = message.content
        logger.debug("Extracting JSON from StructuredMessage: %s", content)
        try:
            if isinstance(content, BaseModel):
                # Prefer the Pydantic v2 API; fall back to v1's .dict().
                dump = getattr(content, "model_dump", None) or content.dict
                content_dict = dump()
                return content_dict.get("slides", content_dict)
            return content
        except Exception as e:
            logger.error("Failed to extract JSON from StructuredMessage: %s, Content: %s", e, content)
            return None

    if isinstance(message, HandoffMessage):
        logger.debug("Extracting JSON from HandoffMessage context")
        for ctx_msg in message.context:
            if not hasattr(ctx_msg, "content"):
                continue
            content = ctx_msg.content
            logger.debug("Handoff context message content: %s", content)
            if isinstance(content, str):
                parsed = _parse_json_text(
                    content, "HandoffMessage context", "raw HandoffMessage context"
                )
                if parsed is not None:
                    return parsed
            elif isinstance(content, dict):
                return content.get("slides", content)
        logger.warning("No JSON found in HandoffMessage context")
        return None

    logger.warning("Unsupported message type for JSON extraction: %s", type(message))
    return None
# Function to generate Markdown and convert to PDF (portrait, centered)
def generate_slides_pdf(slides):
    """Render the slides to a single PDF, one Markdown section per slide.

    Args:
        slides: Iterable of mappings with ``'title'`` and ``'content'`` keys.

    Returns:
        Path of the written file (``slides.pdf`` inside ``OUTPUT_DIR``).

    Raises:
        Exception: Whatever ``MarkdownPdf.save`` raises, re-raised after being
            logged.
    """
    document = MarkdownPdf()
    for slide in slides:
        # Double every newline so each source line becomes its own paragraph.
        content_lines = slide['content'].replace('\n', '\n\n')
        page_markdown = "\n".join(
            [
                "",
                '<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; text-align: center; padding: 20px;">',
                f"# {slide['title']}",
                "*Prof. AI Feynman*",
                "*Princeton University, April 26th, 2025*",
                f"{content_lines}",
                "</div>",
                "---",
                "",
            ]
        )
        document.add_section(Section(page_markdown, toc=False))

    pdf_file = os.path.join(OUTPUT_DIR, "slides.pdf")
    try:
        document.save(pdf_file)
    except Exception as e:
        logger.error("Failed to generate PDF: %s", str(e))
        raise
    logger.info("Generated PDF slides (portrait): %s", pdf_file)
    return pdf_file
# Async function to update audio preview
async def update_audio_preview(audio_file):
    """Echo the selected audio path back to the UI, logging it when present.

    Args:
        audio_file: File path from the audio component, or a falsy value.

    Returns:
        ``audio_file`` unchanged when truthy, otherwise ``None``.
    """
    if not audio_file:
        return None
    logger.info("Updating audio preview for file: %s", audio_file)
    return audio_file
# Lecture generation pipeline: research -> slides -> scripts -> review -> PDF -> audio.
_PANEL_STYLE = (
    "display: flex; flex-direction: column; justify-content: center; "
    "align-items: center; height: 100%; min-height: 700px; padding: 20px; "
    "text-align: center; border: 1px solid #ddd; border-radius: 8px;"
)


def _error_panel(heading, *body_lines):
    """Render the red-titled error card shown in the slide area.

    ``body_lines`` are pre-rendered HTML lines (typically ``<p>`` elements)
    placed under the heading, one per line.
    """
    return "\n".join(
        [
            "",
            f'<div style="{_PANEL_STYLE}">',
            f'<h2 style="color: #d9534f;">{heading}</h2>',
            *body_lines,
            "</div>",
            "",
        ]
    )


def _bind_search_tool(serpapi_key):
    """Return a ``search_web`` tool whose only model-facing argument is ``query``."""

    def bound_search(query: str) -> str:
        """Search the web with Google via SerpApi and return the top results as text."""
        return search_web(query, serpapi_key)

    # Keep the tool name that the research agent's system prompt refers to.
    bound_search.__name__ = "search_web"
    return bound_search


async def on_generate(api_service, api_key, serpapi_key, title, topic, instructions, lecture_type, speaker_audio, num_slides):
    """Generate a full lecture and stream progress/result HTML to the UI.

    Pipeline: initialize XTTS, run a four-agent swarm (research -> slide ->
    script -> review), validate the slide/script counts, render the slides to
    PDF, then synthesize one audio file per script.

    Args:
        api_service: Model label understood by ``get_model_client``.
        api_key: API key for the selected model provider.
        serpapi_key: SerpApi key; bound into the research agent's search tool.
        title: Lecture title.
        topic: Lecture topic to research.
        instructions: Free-form extra instructions for the agents.
        lecture_type: Audience label.
        speaker_audio: Path to a reference voice clip, or falsy to use
            ``feynman.mp3``.
        num_slides: Number of content slides; three more (quiz, assignment,
            thank-you) are always requested.

    Yields:
        HTML strings: progress cards while working, then either the final
        lecture player markup or an error card. Errors are reported through
        the yielded HTML rather than raised.
    """
    if not serpapi_key:
        yield _error_panel(
            "SerpApi key required",
            '<p style="margin-top: 20px;">Please provide a valid SerpApi key and try again.</p>',
        )
        return

    # A UI slider may hand over a float; the prompts and counts need an integer.
    num_slides = int(num_slides)

    # Ensure output directory exists
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    logger.info("Output directory set to: %s", OUTPUT_DIR)

    # Initialize TTS model
    tts = None
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
        logger.info("TTS model initialized on %s", device)
    except Exception as e:
        logger.error("Failed to initialize TTS model: %s", str(e))
        yield _error_panel(
            "TTS model initialization failed",
            f'<p style="margin-top: 20px;">Error: {str(e)}</p>',
            "<p>Please ensure the Coqui TTS model is properly installed and try again.</p>",
        )
        return

    model_client = get_model_client(api_service, api_key)

    research_agent = AssistantAgent(
        name="research_agent",
        model_client=model_client,
        handoffs=["slide_agent"],
        system_message="You are a Research Agent. Use the search_web tool to gather information on the topic and keywords from the initial message. Summarize the findings concisely in a single message, then use the handoff_to_slide_agent tool to pass the task to the Slide Agent. Do not produce any other output.",
        # The key is bound here so the model only has to supply `query`;
        # exposing `serpapi_key` as a tool argument would force it to guess.
        tools=[_bind_search_tool(serpapi_key)]
    )
    slide_agent = AssistantAgent(
        name="slide_agent",
        model_client=model_client,
        handoffs=["script_agent"],
        system_message=f"""
You are a Slide Agent. Using the research from the conversation history, generate EXACTLY {num_slides} content slides, plus 1 quiz slide, 1 assignment slide, and 1 thank-you slide, for a TOTAL of {num_slides + 3} slides. Output ONLY a JSON array wrapped in ```json ... ``` in a TextMessage, with each slide as an object with 'title' and 'content' keys. Ensure the JSON is valid, contains EXACTLY {num_slides + 3} slides, and matches the specified count before proceeding. Do not include explanatory text, comments, or other messages. After outputting, use the handoff_to_script_agent tool.
Example for 2 content slides:
```json
[
{{"title": "Slide 1", "content": "Content for slide 1"}},
{{"title": "Slide 2", "content": "Content for slide 2"}},
{{"title": "Quiz", "content": "Quiz questions"}},
{{"title": "Assignment", "content": "Assignment details"}},
{{"title": "Thank You", "content": "Thank you message"}}
]
```""",
        output_content_type=None,
        reflect_on_tool_use=False
    )
    script_agent = AssistantAgent(
        name="script_agent",
        model_client=model_client,
        handoffs=["feynman_agent"],
        system_message=f"""
You are a Script Agent. Access the JSON array of {num_slides + 3} slides from the conversation history. Generate a narration script (1-2 sentences) for each of the {num_slides + 3} slides, summarizing its content in a natural, conversational tone as a speaker would, including occasional non-verbal words (e.g., "um," "you know," "like"). Output ONLY a JSON array wrapped in ```json ... ``` with exactly {num_slides + 3} strings, one script per slide, in the same order. Ensure the JSON is valid and complete. After outputting, use the handoff_to_feynman_agent tool. If scripts cannot be generated, retry once.
Example for 1 content slide:
```json
[
"So, this slide, um, covers the main topic in a fun way.",
"Alright, you know, answer these quiz questions.",
"Here's your, like, assignment to complete.",
"Thanks for, um, attending today!"
]
```""",
        output_content_type=None,
        reflect_on_tool_use=False
    )
    feynman_agent = AssistantAgent(
        name="feynman_agent",
        model_client=model_client,
        handoffs=[],
        system_message=f"""
You are Agent Feynman. Review the slides and scripts from the conversation history to ensure coherence, completeness, and that EXACTLY {num_slides + 3} slides and {num_slides + 3} scripts are received. Output a confirmation message summarizing the number of slides and scripts received. If slides or scripts are missing, invalid, or do not match the expected count ({num_slides + 3}), report the issue clearly. Use 'TERMINATE' to signal completion.
Example: 'Received {num_slides + 3} slides and {num_slides + 3} scripts. Lecture is coherent. TERMINATE'
""")

    swarm = Swarm(
        participants=[research_agent, slide_agent, script_agent, feynman_agent],
        termination_condition=HandoffTermination(target="user") | TextMentionTermination("TERMINATE")
    )

    progress = 0
    label = "Research: in progress..."
    yield html_with_progress(label, progress)
    await asyncio.sleep(0.1)

    initial_message = f"""
Lecture Title: {title}
Topic: {topic}
Additional Instructions: {instructions}
Audience: {lecture_type}
Number of Content Slides: {num_slides}
Please start by researching the topic.
"""
    logger.info("Starting lecture generation for topic: %s", topic)

    slides = None
    scripts = None

    try:
        max_slide_retries = 2
        total_attempts = max_slide_retries + 1
        expected_slide_count = num_slides + 3

        for attempt_index in range(total_attempts):
            if attempt_index > 0:
                # Start every retry from a clean slate: drop the previous
                # attempt's results and the team's accumulated conversation.
                slides = None
                scripts = None
                await swarm.reset()
            logger.info("Slide generation attempt %d/%d", attempt_index + 1, total_attempts)

            # Exactly one swarm run per attempt.
            task_result = await Console(swarm.run_stream(task=initial_message))
            logger.info("Swarm execution completed")

            for message in task_result.messages:
                source = getattr(message, 'source', getattr(message, 'sender', None))
                logger.debug("Processing message from %s, type: %s, content: %s", source, type(message), message.to_text() if hasattr(message, 'to_text') else str(message))

                if isinstance(message, HandoffMessage):
                    logger.info("Handoff from %s to %s", source, message.target)
                    if source == "research_agent" and message.target == "slide_agent":
                        progress = 25
                        label = "Slides: generating..."
                        yield html_with_progress(label, progress)
                        await asyncio.sleep(0.1)
                    elif source == "slide_agent" and message.target == "script_agent":
                        if slides is None:
                            # The slide agent handed off without a parsable
                            # text message; look in the handoff context.
                            logger.warning("Slide Agent handoff without slides JSON")
                            extracted_json = extract_json_from_message(message)
                            if extracted_json:
                                slides = extracted_json
                                logger.info("Extracted slides JSON from HandoffMessage context: %s", slides)
                        if slides is None:
                            label = "Slides: failed to generate..."
                            yield html_with_progress(label, progress)
                            await asyncio.sleep(0.1)
                        progress = 50
                        label = "Scripts: generating..."
                        yield html_with_progress(label, progress)
                        await asyncio.sleep(0.1)
                    elif source == "script_agent" and message.target == "feynman_agent":
                        if scripts is None:
                            logger.warning("Script Agent handoff without scripts JSON")
                            extracted_json = extract_json_from_message(message)
                            if extracted_json:
                                scripts = extracted_json
                                logger.info("Extracted scripts JSON from HandoffMessage context: %s", scripts)
                        progress = 75
                        label = "Review: in progress..."
                        yield html_with_progress(label, progress)
                        await asyncio.sleep(0.1)

                elif source == "research_agent" and isinstance(message, TextMessage) and "handoff_to_slide_agent" in message.content:
                    logger.info("Research Agent completed research")
                    progress = 25
                    label = "Slides: generating..."
                    yield html_with_progress(label, progress)
                    await asyncio.sleep(0.1)

                elif source == "slide_agent" and isinstance(message, (TextMessage, StructuredMessage)):
                    logger.debug("Slide Agent message received: %s", message.to_text())
                    extracted_json = extract_json_from_message(message)
                    if extracted_json:
                        slides = extracted_json
                        logger.info("Slide Agent generated %d slides: %s", len(slides), slides)
                        # Save slide content to individual files
                        for i, slide in enumerate(slides):
                            content_file = os.path.join(OUTPUT_DIR, f"slide_{i+1}_content.txt")
                            try:
                                with open(content_file, "w", encoding="utf-8") as f:
                                    f.write(slide["content"])
                                logger.info("Saved slide content to %s: %s", content_file, slide["content"])
                            except Exception as e:
                                logger.error("Error saving slide content to %s: %s", content_file, str(e))
                        progress = 50
                        label = "Scripts: generating..."
                        yield html_with_progress(label, progress)
                        await asyncio.sleep(0.1)
                    else:
                        logger.warning("No JSON extracted from slide_agent message: %s", message.to_text())

                elif source == "script_agent" and isinstance(message, (TextMessage, StructuredMessage)):
                    logger.debug("Script Agent message received: %s", message.to_text())
                    extracted_json = extract_json_from_message(message)
                    if extracted_json:
                        scripts = extracted_json
                        logger.info("Script Agent generated scripts for %d slides: %s", len(scripts), scripts)
                        # Save raw scripts to individual files
                        for i, script in enumerate(scripts):
                            script_file = os.path.join(OUTPUT_DIR, f"slide_{i+1}_raw_script.txt")
                            try:
                                with open(script_file, "w", encoding="utf-8") as f:
                                    f.write(script)
                                logger.info("Saved raw script to %s: %s", script_file, script)
                            except Exception as e:
                                logger.error("Error saving raw script to %s: %s", script_file, str(e))
                        progress = 75
                        label = "Scripts generated and saved. Reviewing..."
                        yield html_with_progress(label, progress)
                        await asyncio.sleep(0.1)
                    else:
                        # A previous revision appended a "please retry" message
                        # to the list being iterated here; that never reached
                        # any agent, so only the warning is kept.
                        logger.warning("No JSON extracted from script_agent message: %s", message.to_text())

                elif source == "feynman_agent" and isinstance(message, TextMessage) and "TERMINATE" in message.content:
                    logger.info("Feynman Agent completed lecture review: %s", message.content)
                    progress = 90
                    label = "Lecture materials ready. Generating audio..."
                    yield html_with_progress(label, progress)
                    await asyncio.sleep(0.1)

            # Validate slide count
            if slides and len(slides) == expected_slide_count:
                logger.info("Slide count validated: %d slides received", len(slides))
                break
            # Remember the real count for the final error message.
            received_slide_count = len(slides) if slides else 0
            logger.warning("Incorrect slide count: expected %d, got %d", expected_slide_count, received_slide_count)
        else:
            logger.error("Max slide retries reached")
            yield _error_panel(
                "Incorrect number of slides",
                f'<p style="margin-top: 20px;">Expected {expected_slide_count} slides ({num_slides} content slides + quiz, assignment, thank-you), but generated {received_slide_count}. Please try again with a different model.</p>',
            )
            return

        logger.info("Slides state: %s", "Generated" if slides else "None")
        logger.info("Scripts state: %s", "Generated" if scripts else "None")

        if not slides or not scripts:
            error_message = f"Failed to generate {'slides and scripts' if not slides and not scripts else 'slides' if not slides else 'scripts'}"
            error_message += f". Received {len(slides) if slides else 0} slides and {len(scripts) if scripts else 0} scripts."
            logger.error("%s", error_message)
            logger.debug("Dumping all messages for debugging:")
            for msg in task_result.messages:
                source = getattr(msg, 'source', getattr(msg, 'sender', None))
                logger.debug("Message from %s, type: %s, content: %s", source, type(msg), msg.to_text() if hasattr(msg, 'to_text') else str(msg))
            yield _error_panel(
                "Failed to generate lecture materials",
                '<p style="margin-top: 20px;">Please try again with different parameters or a different model.</p>',
            )
            return

        if not isinstance(scripts, list) or not all(isinstance(s, str) for s in scripts):
            logger.error("Scripts are not a list of strings: %s", scripts)
            yield _error_panel(
                "Invalid script format",
                '<p style="margin-top: 20px;">Scripts must be a list of strings. Please try again.</p>',
            )
            return

        if len(scripts) != expected_slide_count:
            logger.error("Mismatch between number of slides (%d) and scripts (%d)", len(slides), len(scripts))
            yield _error_panel(
                "Mismatch in slides and scripts",
                f'<p style="margin-top: 20px;">Generated {len(slides)} slides but {len(scripts)} scripts. Please try again.</p>',
            )
            return

        # Generate PDF from slides
        try:
            pdf_file = generate_slides_pdf(slides)
        except Exception as e:
            logger.error("PDF generation failed: %s", str(e))
            yield _error_panel(
                "PDF generation failed",
                f'<p style="margin-top: 20px;">Error: {str(e)}</p>',
                "<p>Please try again or check the lecture_generation.log for details.</p>",
            )
            return

        audio_files = []
        speaker_audio = speaker_audio if speaker_audio else "feynman.mp3"
        validated_speaker_wav = await validate_and_convert_speaker_audio(speaker_audio)
        if not validated_speaker_wav:
            logger.error("Invalid speaker audio after conversion, skipping TTS")
            yield _error_panel(
                "Invalid speaker audio",
                '<p style="margin-top: 20px;">Please upload a valid MP3 or WAV audio file and try again.</p>',
            )
            return

        # Process audio generation sequentially with retries
        for i, script in enumerate(scripts):
            cleaned_script = clean_script_text(script)
            audio_file = os.path.join(OUTPUT_DIR, f"slide_{i+1}.wav")
            script_file = os.path.join(OUTPUT_DIR, f"slide_{i+1}_script.txt")

            # Save cleaned script
            try:
                with open(script_file, "w", encoding="utf-8") as f:
                    f.write(cleaned_script or "")
                logger.info("Saved cleaned script to %s: %s", script_file, cleaned_script)
            except Exception as e:
                logger.error("Error saving cleaned script to %s: %s", script_file, str(e))

            if not cleaned_script:
                logger.error("Skipping audio for slide %d due to empty or invalid script", i + 1)
                # Keep audio_files index-aligned with slides.
                audio_files.append(None)
                progress = 90 + ((i + 1) / len(scripts)) * 10
                label = f"Generated audio for slide {i + 1}/{len(scripts)}..."
                yield html_with_progress(label, progress)
                await asyncio.sleep(0.1)
                continue

            max_retries = 2
            for attempt in range(max_retries + 1):
                try:
                    current_text = cleaned_script
                    if attempt > 0:
                        # Retries fall back to at most the first two sentences.
                        sentences = re.split(r"[.!?]+", cleaned_script)
                        sentences = [s.strip() for s in sentences if s.strip()][:2]
                        current_text = ". ".join(sentences) + "."
                        logger.info("Retry %d for slide %d with simplified text: %s", attempt, i + 1, current_text)

                    success = generate_xtts_audio(tts, current_text, validated_speaker_wav, audio_file)
                    if not success:
                        raise RuntimeError("TTS generation failed")

                    logger.info("Generated audio for slide %d: %s", i + 1, audio_file)
                    audio_files.append(audio_file)
                    progress = 90 + ((i + 1) / len(scripts)) * 10
                    label = f"Generated audio for slide {i + 1}/{len(scripts)}..."
                    yield html_with_progress(label, progress)
                    await asyncio.sleep(0.1)
                    break
                except Exception as e:
                    logger.error("Error generating audio for slide %d (attempt %d): %s\n%s", i + 1, attempt, str(e), traceback.format_exc())
                    if attempt == max_retries:
                        logger.error("Max retries reached for slide %d, skipping", i + 1)
                        audio_files.append(None)
                        progress = 90 + ((i + 1) / len(scripts)) * 10
                        label = f"Generated audio for slide {i + 1}/{len(scripts)}..."
                        yield html_with_progress(label, progress)
                        await asyncio.sleep(0.1)
                        break

        # Prepare output HTML with gr.File for PDF and gr.FileExplorer for outputs
        slides_info = json.dumps({"slides": [
            {"title": slide["title"], "content": slide["content"]}
            for slide in slides
        ], "audioFiles": audio_files})

        html_output = f"""
<div id="lecture-container" style="height: 700px; border: 1px solid #ddd; border-radius: 8px; display: flex; flex-direction: column; justify-content: space-between; padding: 20px;">
<div style="flex: 1; overflow: auto;">
<h3>Lecture Slides</h3>
<p>Download or view the slides PDF below (opens in your browser's PDF viewer):</p>
<gradio-file value="{pdf_file}" label="Slides PDF" file_types=[".pdf"]></gradio-file>
<h3>Generated Files</h3>
<p>Explore all generated files (PDF, audio, scripts) in the output directory:</p>
<gradio-file-explorer glob="/data/outputs/*" label="Output Directory"></gradio-file-explorer>
</div>
<div style="padding: 20px;">
<div id="progress-bar" style="width: 100%; height: 5px; background-color: #ddd; border-radius: 2px; margin-bottom: 10px;">
<div id="progress-fill" style="width: {(1/len(slides)*100)}%; height: 100%; background-color: #4CAF50; border-radius: 2px;"></div>
</div>
<div style="display: flex; justify-content: center; margin-bottom: 10px;">
<button onclick="prevSlide()" style="border-radius: 50%; width: 40px; height: 40px; margin: 0 5px; font-size: 1.2em; cursor: pointer;">⏮</button>
<button onclick="togglePlay()" style="border-radius: 50%; width: 40px; height: 40px; margin: 0 5px; font-size: 1.2em; cursor: pointer;">⏯</button>
<button onclick="nextSlide()" style="border-radius: 50%; width: 40px; height: 40px; margin: 0 5px; font-size: 1.2em; cursor: pointer;">⏭</button>
</div>
<p id="slide-counter" style="text-align: center;">Slide 1 of {len(slides)}</p>
</div>
</div>
<script>
const lectureData = {slides_info};
let currentSlide = 0;
const totalSlides = lectureData.slides.length;
const slideCounter = document.getElementById('slide-counter');
const progressFill = document.getElementById('progress-fill');
let audioElements = [];
let currentAudio = null;
for (let i = 0; i < totalSlides; i++) {{
if (lectureData.audioFiles && lectureData.audioFiles[i]) {{
const audio = new Audio('/gradio_api/file=' + lectureData.audioFiles[i]);
audioElements.push(audio);
}} else {{
audioElements.push(null);
}}
}}
function updateSlide() {{
slideCounter.textContent = `Slide ${{currentSlide + 1}} of ${{totalSlides}}`;
progressFill.style.width = `${{(currentSlide + 1) / totalSlides * 100}}%`;
if (currentAudio) {{
currentAudio.pause();
currentAudio.currentTime = 0;
}}
if (audioElements[currentSlide]) {{
currentAudio = audioElements[currentSlide];
currentAudio.play().catch(e => console.error('Audio play failed:', e));
}} else {{
currentAudio = null;
}}
}}
function prevSlide() {{
if (currentSlide > 0) {{
currentSlide--;
updateSlide();
}}
}}
function nextSlide() {{
if (currentSlide < totalSlides - 1) {{
currentSlide++;
updateSlide();
}}
}}
function togglePlay() {{
if (!audioElements[currentSlide]) return;
if (currentAudio.paused) {{
currentAudio.play().catch(e => console.error('Audio play failed:', e));
}} else {{
currentAudio.pause();
}}
}}
audioElements.forEach((audio, index) => {{
if (audio) {{
audio.addEventListener('ended', () => {{
if (index < totalSlides - 1) {{
nextSlide();
}}
}});
}}
}});
</script>
"""
        logger.info("Lecture generation completed successfully")
        yield html_output

    except Exception as e:
        logger.error("Error during lecture generation: %s\n%s", str(e), traceback.format_exc())
        yield _error_panel(
            "Error during lecture generation",
            f'<p style="margin-top: 10px; font-size: 16px;">{str(e)}</p>',
            '<p style="margin-top: 20px;">Please try again or check the lecture_generation.log for details.</p>',
        )
        return
# Gradio interface: a two-column page with the lecture form on the left and
# the progress / result HTML on the right.
# NOTE(review): block nesting below was reconstructed from a source whose
# indentation had been stripped — confirm in particular that the button
# belongs inside the Group.
with gr.Blocks(title="Agent Feynman") as demo:
    gr.Markdown("# <center>Learn Anything With Professor AI Feynman</center>")
    with gr.Row():
        # Left column: lecture parameters, model choice and credentials.
        with gr.Column(scale=1):
            with gr.Group():
                title = gr.Textbox(label="Lecture Title", placeholder="e.g. Introduction to AI")
                topic = gr.Textbox(label="Topic", placeholder="e.g. Artificial Intelligence")
                instructions = gr.Textbox(label="Additional Instructions", placeholder="e.g. Focus on recent advancements")
                lecture_type = gr.Dropdown(["Conference", "University", "High school"], label="Audience", value="University")
                # Choices must match the labels handled by get_model_client.
                api_service = gr.Dropdown(
                    choices=[
                        "OpenAI-gpt-4o-2024-08-06",
                        "Anthropic-claude-3-sonnet-20240229",
                        "Google-gemini-1.5-flash",
                        "Ollama-llama3.2"
                    ],
                    label="Model",
                    value="Google-gemini-1.5-flash"
                )
                api_key = gr.Textbox(label="Model Provider API Key", type="password", placeholder="Not required for Ollama")
                serpapi_key = gr.Textbox(label="SerpApi Key", type="password", placeholder="Enter your SerpApi key")
                num_slides = gr.Slider(1, 20, step=1, label="Number of Content Slides", value=3)
                # type="filepath" hands on_generate a path rather than raw audio data.
                speaker_audio = gr.Audio(label="Speaker sample audio (MP3 or WAV)", type="filepath", elem_id="speaker-audio")
                generate_btn = gr.Button("Generate Lecture")
        # Right column: placeholder card that on_generate's output replaces.
        with gr.Column(scale=2):
            default_slide_html = """
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;">
<h2 style="font-style: italic; color: #555;">Waiting for lecture content...</h2>
<p style="margin-top: 10px; font-size: 16px;">Please Generate lecture content via the form on the left first before lecture begins</p>
</div>
"""
            slide_display = gr.HTML(label="Lecture Slides", value=default_slide_html)
    # Feed the selected audio path back into the same Audio component.
    speaker_audio.change(
        fn=update_audio_preview,
        inputs=speaker_audio,
        outputs=speaker_audio
    )
    # on_generate is an async generator: each yielded HTML string replaces the
    # content of slide_display. Input order must match its parameter order.
    generate_btn.click(
        fn=on_generate,
        inputs=[api_service, api_key, serpapi_key, title, topic, instructions, lecture_type, speaker_audio, num_slides],
        outputs=[slide_display]
    )
# Start the Gradio server only when run as a script.
if __name__ == "__main__":
    demo.launch()