# MangakAI / manga.py
# Last change: "Update manga.py" by Shiva4113 (commit 5e92350)
import os
import PIL
from io import BytesIO
from google import genai
from dotenv import load_dotenv
import pathlib
from utils import (
get_scene_splitting_prompt,
get_panel_prompt,
SCENE_BREAK_DELIMITER,
get_regeneration_prompt
)
import time
from PIL import Image
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import A4
import shutil
import uuid
import hashlib
from typing import Dict, Optional
# Load environment variables
load_dotenv()
class UserSession:
    """One user's isolated workspace: API key, activity timestamps and private folders.

    A short SHA-256 digest of the key is exposed as ``api_key_hash`` for
    identification; the raw key is also held (unencrypted) in memory so that
    API clients can be built from it via ``get_api_key()``.
    """

    def __init__(self, session_id: str, api_key: str):
        self.session_id = session_id
        self.api_key_hash = self._hash_api_key(api_key)
        # Raw key kept in plain text on this object only; this class never writes it to disk.
        self._api_key = api_key
        self.created_at = time.time()
        self.last_activity = time.time()

        # Per-session folder tree: data/sessions/<id>/{output,templates,references}
        root = os.path.join("data", "sessions", session_id)
        self.base_dir = root
        self.output_dir = os.path.join(root, "output")
        self.user_templates_dir = os.path.join(root, "templates")
        self.user_references_dir = os.path.join(root, "references")
        for folder in (self.output_dir, self.user_templates_dir, self.user_references_dir):
            os.makedirs(folder, exist_ok=True)

        # Mutable state shared by every generator built for this session.
        fallback_template = os.getenv("TEMPLATE_PATH", "data/templates/template.png")
        self.current_generation = dict(
            scenes=[],
            generated_images=[],
            chat=None,
            user_preferences={},
            current_template_path=fallback_template,
        )

    def _hash_api_key(self, api_key: str) -> str:
        """Return the first 16 hex chars of the key's SHA-256 digest (an identifier, not storage)."""
        digest = hashlib.sha256(api_key.encode()).hexdigest()
        return digest[:16]

    def get_api_key(self) -> str:
        """Return the raw API key held in memory (no encryption is applied here)."""
        return self._api_key

    def update_activity(self):
        """Stamp the session as used right now."""
        self.last_activity = time.time()

    def is_expired(self, timeout_hours: int = 24) -> bool:
        """Return True when the session has been idle for longer than ``timeout_hours``."""
        idle_seconds = time.time() - self.last_activity
        return idle_seconds > timeout_hours * 3600

    def cleanup(self):
        """Delete the session's folder tree; failures are printed, never raised."""
        try:
            if not os.path.exists(self.base_dir):
                return
            shutil.rmtree(self.base_dir)
            print(f"Cleaned up session: {self.session_id}")
        except Exception as exc:
            print(f"Error cleaning up session {self.session_id}: {exc}")
class SessionManager:
    """In-memory registry of ``UserSession`` objects keyed by session ID."""

    def __init__(self):
        self.sessions: Dict[str, UserSession] = {}
        # Idle hours after which a session counts as expired.
        self.session_timeout_hours = 4

    def create_session(self, api_key: str) -> str:
        """Register a fresh session for ``api_key`` and return its random UUID string."""
        new_id = str(uuid.uuid4())
        self.sessions[new_id] = UserSession(new_id, api_key)
        # Creation doubles as the trigger for purging stale sessions.
        self._cleanup_expired_sessions()
        return new_id

    def get_session(self, session_id: str) -> Optional[UserSession]:
        """Return the live session for ``session_id`` (refreshing its activity), else None.

        An expired session is removed (including its files) and None is returned.
        """
        found = self.sessions.get(session_id)
        if found is None:
            return None
        if found.is_expired(self.session_timeout_hours):
            self._remove_session(session_id)
            return None
        found.update_activity()
        return found

    def _remove_session(self, session_id: str):
        """Clean up a session's files and drop it from the registry, if present."""
        victim = self.sessions.get(session_id)
        if victim is None:
            return
        victim.cleanup()
        del self.sessions[session_id]

    def _cleanup_expired_sessions(self):
        """Remove every session that has exceeded the idle timeout."""
        limit = self.session_timeout_hours
        stale_ids = list(filter(lambda sid: self.sessions[sid].is_expired(limit), self.sessions))
        for sid in stale_ids:
            self._remove_session(sid)
class MangaGenerator:
    """Generate manga panels for one user session using Gemini models.

    All generation state (scenes, generated images, chat, preferences and the
    active template) is read from and written to ``session.current_generation``,
    so it outlives any single ``MangaGenerator`` instance. Files are written to
    the session's private directories.
    """

    def __init__(self, session: "UserSession"):
        """Initialize the manga generator for a specific user session.

        Args:
            session: Provides the API key, output/template/reference
                directories and the mutable ``current_generation`` dict.

        Raises:
            ValueError: If a ``genai.Client`` cannot be constructed.
        """
        self.session = session
        self.api_key = session.get_api_key()
        # Default template; a user upload may override it per generation.
        self.template_path = os.getenv("TEMPLATE_PATH", "data/templates/template.png")
        self.output_dir = session.output_dir
        self.user_templates_dir = session.user_templates_dir
        self.user_references_dir = session.user_references_dir
        self.image_gen_model_name = os.getenv("IMAGE_MODEL_NAME", "gemini-2.5-flash-image-preview")
        self.scene_gen_model_name = os.getenv("SCENE_MODEL_NAME", "gemini-2.0-flash")
        try:
            self.image_gen_client = genai.Client(api_key=self.api_key)
            self.scene_gen_client = genai.Client(api_key=self.api_key)
        except Exception as e:
            raise ValueError(f"Invalid API key or client initialization failed: {e}") from e

    @property
    def current_generation(self):
        """Access current generation data from session."""
        return self.session.current_generation

    # ------------------------------------------------------------------ uploads

    def _save_uploaded_image(self, uploaded_file, target_dir, prefix, label):
        """Store an uploaded image in ``target_dir`` as PNG-named file and verify it.

        Accepts a file object exposing ``.name`` (e.g. a Gradio upload), a path
        string, or an object with ``.save(path)`` (e.g. a PIL image).

        Returns:
            The saved path, or None when ``uploaded_file`` is None or the file
            could not be saved/verified (the error is printed, not raised, and
            any partially written file is removed).
        """
        if uploaded_file is None:
            return None
        # The random suffix prevents two uploads in the same second from
        # overwriting each other.
        filename = f"{prefix}_{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
        saved_path = os.path.join(target_dir, filename)
        try:
            if hasattr(uploaded_file, 'name'):  # Gradio file object
                shutil.copy2(uploaded_file.name, saved_path)
            elif isinstance(uploaded_file, str):  # Path on disk
                shutil.copy2(uploaded_file, saved_path)
            elif hasattr(uploaded_file, 'save'):  # PIL Image
                uploaded_file.save(saved_path)
            else:
                raise TypeError(f"Unsupported upload type: {type(uploaded_file).__name__}")
            # verify() checks integrity; the context manager closes the file handle.
            with Image.open(saved_path) as probe:
                probe.verify()
        except Exception as e:
            print(f"Error saving {label}: {e}")
            try:
                os.remove(saved_path)  # Don't leave an invalid file behind.
            except OSError:
                pass
            return None
        print(f"{label.capitalize()} saved to: {saved_path}")
        return saved_path

    def save_user_template(self, uploaded_file):
        """Save user uploaded template and return the path (None on failure)."""
        return self._save_uploaded_image(uploaded_file, self.user_templates_dir, "user_template", "user template")

    def save_user_reference_image(self, uploaded_file):
        """Save user uploaded reference image and return the path (None on failure)."""
        return self._save_uploaded_image(
            uploaded_file, self.user_references_dir, "user_reference", "user reference image"
        )

    # ---------------------------------------------------------------- templates

    def set_template_for_generation(self, template_path):
        """Use ``template_path`` for subsequent panels if it exists; return True on success."""
        if template_path and os.path.exists(template_path):
            self.current_generation['current_template_path'] = template_path
            return True
        return False

    def get_current_template_path(self):
        """Return the active template path, falling back to the default template."""
        return self.current_generation.get('current_template_path', self.template_path)

    # ------------------------------------------------------------------- scenes

    def read_story(self, file_path):
        """Read story text from a UTF-8 file (explicit so behavior is platform-independent)."""
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()

    def split_into_scenes(self, story_text: str, n_scenes: int):
        """Split story into at most ``n_scenes`` visual scene descriptions.

        Raises:
            ValueError: If the scene model returns no usable text.
        """
        # UI sliders may deliver floats; slicing and the prompt need an int.
        n_scenes = int(n_scenes)
        prompt = get_scene_splitting_prompt(story_text, n_scenes)
        response = self.scene_gen_client.models.generate_content(
            model=self.scene_gen_model_name,
            contents=[prompt],
        )
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = (content.parts if content is not None else None) or []
        full_response_text = "".join(part.text for part in parts if part.text)
        scenes = [scene.strip() for scene in full_response_text.split(SCENE_BREAK_DELIMITER)]
        scenes = [scene for scene in scenes if scene]
        if not scenes:
            raise ValueError("The scene model returned no usable text; could not split the story into scenes.")
        return scenes[:n_scenes]

    # ------------------------------------------------------------------- images

    def save_image(self, response, path):
        """Save the first image found in ``response`` to ``path`` and return it (None if absent)."""
        # NOTE(review): fixed pause, presumably to pace requests to the image API.
        time.sleep(3)
        for part in response.parts or []:
            image = part.as_image()
            if image:
                image.save(path)
                return image
        return None

    def generate_image_for_scene(self, scene_description: str, output_path: str):
        """Generate an image for one scene without chat context; return (response, image)."""
        current_template = self.get_current_template_path()
        response = self.image_gen_client.models.generate_content(
            model=self.image_gen_model_name,
            contents=[scene_description, Image.open(current_template)],
        )
        saved_image = self.save_image(response, output_path)
        return response, saved_image

    def generate_image_with_chat(self, scene_description: str, output_path: str, chat):
        """Generate image using chat context for consistency; return (response, image)."""
        return self.generate_image_with_chat_and_reference(scene_description, output_path, chat)

    def generate_image_with_chat_and_reference(self, scene_description: str, output_path: str, chat,
                                               reference_image_path=None):
        """Generate image using chat context with an optional reference image; return (response, image)."""
        current_template = self.get_current_template_path()
        content = [scene_description, Image.open(current_template)]
        if reference_image_path and os.path.exists(reference_image_path):
            content.append(Image.open(reference_image_path))
            print(f"Using reference image: {reference_image_path}")
        response = chat.send_message(content)
        saved_image = self.save_image(response, output_path)
        return response, saved_image

    # ------------------------------------------------------------ regeneration

    def regenerate_specific_panel(self, panel_index: int, modification_request: str, reference_image=None):
        """Regenerate one panel with modifications and an optional reference image.

        Returns:
            ``(output_path, saved_image)``; the new file is versioned
            (``scene{n}_v{version+1}.png``) so the original is kept on disk.

        Raises:
            ValueError: If nothing has been generated yet or the index is out of range.
        """
        scenes = self.current_generation['scenes']
        if not scenes or not self.current_generation['chat']:
            raise ValueError("No active generation session. Please generate manga first.")
        # Also reject negatives: Python would silently wrap -1 to the last panel.
        if not 0 <= panel_index < len(scenes):
            raise ValueError(f"Panel index {panel_index} is out of range.")
        original_scene = scenes[panel_index]
        reference_image_path = None
        if reference_image is not None:
            reference_image_path = self.save_user_reference_image(reference_image)
            if reference_image_path:
                modification_request += "\n\nIMPORTANT: Use the provided reference image as visual guidance for style, composition, or specific elements while maintaining the story's integrity."
        user_preferences = self.current_generation.get('user_preferences', {})
        modified_prompt = get_regeneration_prompt(
            original_scene,
            modification_request,
            is_first_panel=(panel_index == 0),
            user_preferences=user_preferences,
        )
        current_version = self.current_generation['generated_images'][panel_index].get('version', 1)
        output_path = os.path.join(self.output_dir, f"scene{panel_index + 1}_v{current_version + 1}.png")
        _response, saved_image = self.generate_image_with_chat_and_reference(
            modified_prompt,
            output_path,
            self.current_generation['chat'],
            reference_image_path,
        )
        return output_path, saved_image

    def replace_panel(self, panel_index: int, new_image_path: str, new_image: "Image.Image"):
        """Point panel ``panel_index`` at a new image and bump its version.

        Raises:
            ValueError: If nothing has been generated yet or the index is out of range.
        """
        panels = self.current_generation['generated_images']
        if not panels:
            raise ValueError("No active generation session.")
        if not 0 <= panel_index < len(panels):
            raise ValueError(f"Panel index {panel_index} is out of range.")
        panel = panels[panel_index]
        panel.update({
            'image_path': new_image_path,
            'image': new_image,
            'version': panel.get('version', 1) + 1,
        })

    def get_current_gallery_paths(self):
        """Return the image path of every current panel, in order."""
        return [img['image_path'] for img in self.current_generation['generated_images']]

    # --------------------------------------------------------------- generation

    def generate_manga_from_story(self, story_text: str, n_scenes: int = 5, user_preferences: dict = None,
                                  user_template=None):
        """Generate a full manga from story text.

        Args:
            story_text: The story to illustrate.
            n_scenes: Maximum number of panels.
            user_preferences: Style preferences forwarded to the prompt builders.
            user_template: Optional uploaded template; on save failure the
                previously active template is kept.

        Returns:
            ``(generated_images, scenes)`` where each image entry has keys
            ``scene_number``, ``scene_text``, ``image_path``, ``image``, ``version``.
        """
        if user_template is not None:
            template_path = self.save_user_template(user_template)
            if template_path:
                self.set_template_for_generation(template_path)
                print(f"Using user template: {template_path}")
            else:
                print("Failed to save user template, using default template")
        if user_preferences is None:
            user_preferences = {}
        scenes = self.split_into_scenes(story_text, n_scenes)
        # One chat for all panels so the model keeps visual context between them.
        chat = self.image_gen_client.chats.create(model=self.image_gen_model_name)
        generated_images = []
        for i, scene in enumerate(scenes):
            scene_description = get_panel_prompt(scene, is_first_panel=(i == 0), user_preferences=user_preferences)
            output_path = os.path.join(self.output_dir, f"scene{i + 1}.png")
            _response, saved_image = self.generate_image_with_chat(scene_description, output_path, chat)
            generated_images.append({
                'scene_number': i + 1,
                'scene_text': scene,
                'image_path': output_path,
                'image': saved_image,
                'version': 1,
            })
            if saved_image is None:
                print(f"Warning: no image returned for scene {i + 1}")
            else:
                print(f"Generated scene {i + 1}")
        self.current_generation.update({
            'scenes': scenes,
            'generated_images': generated_images,
            'chat': chat,
            'user_preferences': user_preferences,
        })
        return generated_images, scenes

    def generate_manga_from_file(self, story_file_path: str, n_scenes: int = 5, user_preferences: dict = None,
                                 user_template=None):
        """Generate manga from a UTF-8 story file; see ``generate_manga_from_story``."""
        story_text = self.read_story(story_file_path)
        return self.generate_manga_from_story(story_text, n_scenes, user_preferences, user_template)

    def get_current_panels(self):
        """Return ``[(panel_number, image_path), ...]`` for the current generation."""
        return [(i + 1, img['image_path']) for i, img in enumerate(self.current_generation['generated_images'])]

    # ---------------------------------------------------------------------- PDF

    @staticmethod
    def _fit_image_size(img_width, img_height, max_width, max_height):
        """Return ``(width, height)`` scaled to fit inside ``max_width`` x ``max_height``.

        Aspect ratio is preserved and images are never enlarged.
        """
        scale = min(1.0, max_width / img_width, max_height / img_height)
        return img_width * scale, img_height * scale

    def create_manga_pdf(self, output_filename=None):
        """Write a title page plus one page per existing panel image to a PDF.

        Args:
            output_filename: Target path; defaults to ``<output_dir>/manga_complete.pdf``.

        Returns:
            The PDF path.

        Raises:
            ValueError: If there are no panels.
        """
        panels = self.current_generation['generated_images']
        if not panels:
            raise ValueError("No manga panels to export. Please generate manga first.")
        if output_filename is None:
            output_filename = os.path.join(self.output_dir, "manga_complete.pdf")
        c = canvas.Canvas(output_filename, pagesize=A4)
        page_width, page_height = A4

        # Title page
        c.setFont("Helvetica-Bold", 24)
        title_text = "Generated Manga"
        title_width = c.stringWidth(title_text, "Helvetica-Bold", 24)
        c.drawString((page_width - title_width) / 2, page_height - 100, title_text)
        c.setFont("Helvetica", 12)
        subtitle_text = f"Total Panels: {len(panels)}"
        subtitle_width = c.stringWidth(subtitle_text, "Helvetica", 12)
        c.drawString((page_width - subtitle_width) / 2, page_height - 130, subtitle_text)
        if self.get_current_template_path() != self.template_path:
            template_info = "Custom Template Used"
            template_width = c.stringWidth(template_info, "Helvetica", 12)
            c.drawString((page_width - template_width) / 2, page_height - 150, template_info)
        c.showPage()

        # Fit each image in both dimensions so square/near-square images
        # cannot spill past the page edges.
        max_width = page_width - 100
        max_height = page_height - 150
        for i, panel_data in enumerate(panels):
            image_path = panel_data['image_path']
            if not os.path.exists(image_path):
                print(f"Skipping panel {i + 1}: image not found at {image_path}")
                continue
            with Image.open(image_path) as img:
                img_width, img_height = img.size
            new_width, new_height = self._fit_image_size(img_width, img_height, max_width, max_height)
            x = (page_width - new_width) / 2
            y = (page_height - new_height) / 2
            c.drawImage(image_path, x, y, width=new_width, height=new_height)
            c.setFont("Helvetica", 10)
            c.drawString(50, page_height - 30, f"Panel {i + 1}")
            c.showPage()
        c.save()
        print(f"PDF saved to: {output_filename}")
        return output_filename
# One SessionManager instance shared by every caller of this module.
_session_manager = SessionManager()


def get_session_manager() -> SessionManager:
    """Return the shared module-level SessionManager instance."""
    manager = _session_manager
    return manager
def get_generator_for_session(session_id: str) -> Optional[MangaGenerator]:
    """Build a MangaGenerator bound to ``session_id``.

    Returns None when the session is unknown or expired. A new generator
    (with fresh API clients) is created on every call; ``MangaGenerator``
    raises ValueError if its clients cannot be initialized.
    """
    live_session = _session_manager.get_session(session_id)
    return None if live_session is None else MangaGenerator(live_session)
def create_user_session(api_key: str) -> str:
    """Register a new session for ``api_key`` (whitespace-trimmed) and return its ID.

    Raises:
        ValueError: If the key is missing, empty, or only whitespace.
    """
    cleaned_key = api_key.strip() if api_key else ""
    if not cleaned_key:
        raise ValueError("API key is required")
    return _session_manager.create_session(cleaned_key)
# Interface functions with session support
def generate_manga_interface(session_id: str, story_text: str, num_scenes: int = 5, art_style: str = None, mood: str = None,
                             color_palette: str = None, character_style: str = None, line_style: str = None,
                             composition: str = None, additional_notes: str = "", user_template=None):
    """Interface function for Gradio - generates manga from text input with session support.

    Args:
        session_id: ID returned by ``create_user_session``.
        story_text: The story to illustrate.
        num_scenes: Number of panels to request.
        art_style, mood, color_palette, character_style, line_style, composition:
            Optional style choices; empty values and the literal string
            ``"None"`` are ignored.
        additional_notes: Free-text notes added to the preferences when non-blank.
        user_template: Optional uploaded panel template.

    Returns:
        ``(image_paths, scene_summary)`` on success, otherwise ``([], message)``.
    """
    try:
        generator = get_generator_for_session(session_id)
        if generator is None:
            return [], "Session expired or invalid. Please refresh and enter your API key again."
        # Don't spend API calls on an empty story.
        if not story_text or not story_text.strip():
            return [], "Please enter a story before generating."
        style_choices = {
            'art_style': art_style,
            'mood': mood,
            'color_palette': color_palette,
            'character_style': character_style,
            'line_style': line_style,
            'composition': composition,
        }
        # The literal string "None" is treated as "no selection" (presumably a
        # dropdown placeholder value).
        user_preferences = {key: value for key, value in style_choices.items() if value and value != "None"}
        # additional_notes may be None if the UI sends no value.
        notes = (additional_notes or "").strip()
        if notes:
            user_preferences['additional_notes'] = notes
        generated_images, _scenes = generator.generate_manga_from_story(
            story_text, num_scenes, user_preferences, user_template
        )
        image_paths = [img['image_path'] for img in generated_images]
        summary = "\n\n".join(f"Scene {i + 1}: {img['scene_text']}" for i, img in enumerate(generated_images))
        return image_paths, summary
    except Exception as e:
        print(f"Error generating manga: {e}")
        return [], f"Error: {str(e)}"
def generate_manga_from_file_interface(session_id: str, story_file, num_scenes: int = 5, art_style: str = None, mood: str = None,
                                       color_palette: str = None, character_style: str = None, line_style: str = None,
                                       composition: str = None, additional_notes: str = "", user_template=None):
    """Interface function for Gradio - generates manga from uploaded file with session support.

    Args:
        session_id: ID returned by ``create_user_session``.
        story_file: A file object exposing ``.name``, a path to an existing
            file, or the story text itself. Files are read as UTF-8.
        num_scenes, art_style, mood, color_palette, character_style,
        line_style, composition, additional_notes, user_template:
            Same meaning as in ``generate_manga_interface``.

    Returns:
        ``(image_paths, scene_summary)`` on success, otherwise ``([], message)``.
    """
    try:
        generator = get_generator_for_session(session_id)
        if generator is None:
            return [], "Session expired or invalid. Please refresh and enter your API key again."
        # Without this guard str(None) would become the story "None".
        if story_file is None:
            return [], "Please upload a story file."
        if hasattr(story_file, 'name'):  # Gradio file object
            with open(story_file.name, 'r', encoding='utf-8') as f:
                story_text = f.read()
        elif isinstance(story_file, str) and os.path.isfile(story_file):
            # A plain string may be an upload's filepath (e.g. gr.File(type="filepath"))
            # rather than the story text itself.
            with open(story_file, 'r', encoding='utf-8') as f:
                story_text = f.read()
        else:
            story_text = str(story_file)
        if not story_text.strip():
            return [], "The uploaded story file is empty."
        style_choices = {
            'art_style': art_style,
            'mood': mood,
            'color_palette': color_palette,
            'character_style': character_style,
            'line_style': line_style,
            'composition': composition,
        }
        # The literal string "None" is treated as "no selection" (presumably a
        # dropdown placeholder value).
        user_preferences = {key: value for key, value in style_choices.items() if value and value != "None"}
        # additional_notes may be None if the UI sends no value.
        notes = (additional_notes or "").strip()
        if notes:
            user_preferences['additional_notes'] = notes
        generated_images, _scenes = generator.generate_manga_from_story(
            story_text, num_scenes, user_preferences, user_template
        )
        image_paths = [img['image_path'] for img in generated_images]
        summary = "\n\n".join(f"Scene {i + 1}: {img['scene_text']}" for i, img in enumerate(generated_images))
        return image_paths, summary
    except Exception as e:
        print(f"Error generating manga: {e}")
        return [], f"Error: {str(e)}"
def regenerate_and_replace_interface(session_id: str, panel_number: int, modification_request: str, replace_original: bool, reference_image=None):
    """Interface function for Gradio - regenerate panel with session support.

    Args:
        session_id: ID returned by ``create_user_session``.
        panel_number: 1-based panel number (float values such as 2.0 are accepted).
        modification_request: Instructions describing the change.
        replace_original: If True, the new image replaces the panel in the gallery.
        reference_image: Optional reference upload passed to the generator.

    Returns:
        ``(status_message, new_image_path_or_None, updated_gallery_paths)``;
        the gallery list is empty unless the panel was replaced.
    """
    try:
        generator = get_generator_for_session(session_id)
        if generator is None:
            return "Session expired or invalid. Please refresh and enter your API key again.", None, []
        if not modification_request or not modification_request.strip():
            return "Please provide modification instructions.", None, []
        # Numeric UI inputs may arrive as floats; list indexing needs an int.
        panel_number = int(panel_number)
        # Panel 0 would map to index -1 and silently edit the last panel.
        if panel_number < 1:
            return "Panel number must be 1 or greater.", None, []
        panel_index = panel_number - 1
        new_image_path, saved_image = generator.regenerate_specific_panel(panel_index, modification_request, reference_image)
        updated_gallery = []
        if replace_original and generator.current_generation['generated_images']:
            generator.replace_panel(panel_index, new_image_path, saved_image)
            updated_gallery = generator.get_current_gallery_paths()
            status_message = f"Panel {panel_number} regenerated and replaced successfully!"
        else:
            status_message = f"Panel {panel_number} regenerated successfully! (Original preserved)"
        return status_message, new_image_path, updated_gallery
    except Exception as e:
        return f"Error regenerating panel: {e}", None, []
def create_pdf_interface(session_id: str):
    """Gradio handler: export the session's panels to PDF.

    Returns:
        ``(status_message, pdf_path)`` on success, or ``(message, None)`` when
        the session is invalid, there is nothing to export, or an error occurs.
    """
    try:
        generator = get_generator_for_session(session_id)
        if generator is None:
            return "Session expired or invalid. Please refresh and enter your API key again.", None
        if not generator.current_generation['generated_images']:
            return "No manga panels to export. Please generate manga first.", None
        pdf_path = generator.create_manga_pdf()
        panel_count = len(generator.current_generation['generated_images'])
        return f"PDF created successfully! ({panel_count} panels)", pdf_path
    except Exception as err:
        return f"Error creating PDF: {err}", None
def get_current_panels(session_id: str):
    """Return ``[(panel_number, image_path), ...]`` for the session, or [] if it is unknown or expired."""
    generator = get_generator_for_session(session_id)
    return [] if generator is None else generator.get_current_panels()