Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,229 +1,283 @@ | |
| 1 | 
            -
             | 
| 2 | 
             
            import os
         | 
| 3 | 
            -
            import  | 
| 4 | 
            -
             | 
| 5 | 
            -
             | 
| 6 | 
            -
             | 
| 7 | 
            -
            import  | 
| 8 | 
            -
            import  | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
             | 
| 20 | 
            -
             | 
| 21 | 
            -
             | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 34 |  | 
| 35 | 
            -
                 | 
| 36 | 
            -
             | 
|  | |
|  | |
| 37 |  | 
| 38 | 
            -
                 | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
                     | 
| 42 |  | 
| 43 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 44 |  | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
| 47 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 48 |  | 
| 49 | 
            -
                #  | 
| 50 | 
            -
                 | 
| 51 | 
            -
             | 
| 52 | 
            -
             | 
| 53 | 
            -
             | 
| 54 | 
            -
                return image_files
         | 
| 55 | 
            -
             | 
| 56 | 
            -
            def load_predefined_images():
         | 
| 57 | 
            -
                # Return empty list since we're not using predefined images
         | 
| 58 | 
            -
                return []
         | 
| 59 | 
            -
             | 
| 60 | 
            -
            @spaces.GPU(duration=120)
         | 
| 61 | 
            -
            def inference(
         | 
| 62 | 
            -
                prompt: str,
         | 
| 63 | 
            -
                seed: int,
         | 
| 64 | 
            -
                randomize_seed: bool,
         | 
| 65 | 
            -
                width: int,
         | 
| 66 | 
            -
                height: int,
         | 
| 67 | 
            -
                guidance_scale: float,
         | 
| 68 | 
            -
                num_inference_steps: int,
         | 
| 69 | 
            -
                lora_scale: float,
         | 
| 70 | 
            -
                progress: gr.Progress = gr.Progress(track_tqdm=True),
         | 
| 71 | 
            -
            ):
         | 
| 72 | 
            -
                if randomize_seed:
         | 
| 73 | 
            -
                    seed = random.randint(0, MAX_SEED)
         | 
| 74 | 
            -
                generator = torch.Generator(device=device).manual_seed(seed)
         | 
| 75 |  | 
| 76 | 
            -
                 | 
| 77 | 
            -
             | 
| 78 | 
            -
                     | 
| 79 | 
            -
                     | 
| 80 | 
            -
                     | 
| 81 | 
            -
                     | 
| 82 | 
            -
                     | 
| 83 | 
            -
                     | 
| 84 | 
            -
             | 
|  | |
| 85 |  | 
| 86 | 
            -
                #  | 
| 87 | 
            -
                 | 
| 88 |  | 
| 89 | 
            -
                 | 
| 90 | 
            -
             | 
| 91 |  | 
| 92 | 
            -
             | 
| 93 | 
            -
                " | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 94 |  | 
| 95 | 
            -
                "A soldier standing at attention in full military gear, holding a standard-issue rifle. His uniform is crisp and properly adorned with medals. Behind him, other soldiers march in formation during a military parade. The scene conveys discipline and duty. [president yoon]",
         | 
| 96 |  | 
| 97 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
| 98 |  | 
| 99 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 100 |  | 
| 101 | 
            -
                 | 
|  | |
| 102 |  | 
| 103 | 
            -
                 | 
| 104 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 105 |  | 
| 106 | 
            -
             | 
| 107 | 
            -
             | 
| 108 | 
            -
             | 
| 109 | 
            -
             | 
| 110 | 
            -
            "" | 
| 111 |  | 
| 112 | 
            -
             | 
| 113 | 
            -
             | 
|  | |
|  | |
|  | |
| 114 |  | 
| 115 | 
            -
                 | 
| 116 | 
            -
             | 
| 117 | 
            -
             | 
| 118 | 
            -
             | 
| 119 | 
            -
             | 
| 120 | 
            -
                     | 
| 121 | 
            -
             | 
| 122 | 
            -
             | 
| 123 | 
            -
                                prompt = gr.Text(
         | 
| 124 | 
            -
                                    label="Prompt",
         | 
| 125 | 
            -
                                    show_label=False,
         | 
| 126 | 
            -
                                    max_lines=1,
         | 
| 127 | 
            -
                                    placeholder="Enter your prompt",
         | 
| 128 | 
            -
                                    container=False,
         | 
| 129 | 
            -
                                )
         | 
| 130 | 
            -
                                run_button = gr.Button("Run", scale=0)
         | 
| 131 | 
            -
             | 
| 132 | 
            -
                            result = gr.Image(label="Result", show_label=False)
         | 
| 133 | 
            -
             | 
| 134 | 
            -
                            with gr.Accordion("Advanced Settings", open=False):
         | 
| 135 | 
            -
                                seed = gr.Slider(
         | 
| 136 | 
            -
                                    label="Seed",
         | 
| 137 | 
            -
                                    minimum=0,
         | 
| 138 | 
            -
                                    maximum=MAX_SEED,
         | 
| 139 | 
            -
                                    step=1,
         | 
| 140 | 
            -
                                    value=42,
         | 
| 141 | 
            -
                                )
         | 
| 142 | 
            -
                                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
         | 
| 143 | 
            -
             | 
| 144 | 
            -
                                with gr.Row():
         | 
| 145 | 
            -
                                    width = gr.Slider(
         | 
| 146 | 
            -
                                        label="Width",
         | 
| 147 | 
            -
                                        minimum=256,
         | 
| 148 | 
            -
                                        maximum=MAX_IMAGE_SIZE,
         | 
| 149 | 
            -
                                        step=32,
         | 
| 150 | 
            -
                                        value=1024,
         | 
| 151 | 
            -
                                    )
         | 
| 152 | 
            -
                                    height = gr.Slider(
         | 
| 153 | 
            -
                                        label="Height",
         | 
| 154 | 
            -
                                        minimum=256,
         | 
| 155 | 
            -
                                        maximum=MAX_IMAGE_SIZE,
         | 
| 156 | 
            -
                                        step=32,
         | 
| 157 | 
            -
                                        value=768,
         | 
| 158 | 
            -
                                    )
         | 
| 159 | 
            -
             | 
| 160 | 
            -
                                with gr.Row():
         | 
| 161 | 
            -
                                    guidance_scale = gr.Slider(
         | 
| 162 | 
            -
                                        label="Guidance scale",
         | 
| 163 | 
            -
                                        minimum=0.0,
         | 
| 164 | 
            -
                                        maximum=10.0,
         | 
| 165 | 
            -
                                        step=0.1,
         | 
| 166 | 
            -
                                        value=3.5,
         | 
| 167 | 
            -
                                    )
         | 
| 168 | 
            -
                                    num_inference_steps = gr.Slider(
         | 
| 169 | 
            -
                                        label="Number of inference steps",
         | 
| 170 | 
            -
                                        minimum=1,
         | 
| 171 | 
            -
                                        maximum=50,
         | 
| 172 | 
            -
                                        step=1,
         | 
| 173 | 
            -
                                        value=30,
         | 
| 174 | 
            -
                                    )
         | 
| 175 | 
            -
                                    lora_scale = gr.Slider(
         | 
| 176 | 
            -
                                        label="LoRA scale",
         | 
| 177 | 
            -
                                        minimum=0.0,
         | 
| 178 | 
            -
                                        maximum=1.0,
         | 
| 179 | 
            -
                                        step=0.1,
         | 
| 180 | 
            -
                                        value=1.0,
         | 
| 181 | 
            -
                                    )
         | 
| 182 | 
            -
             | 
| 183 | 
            -
                            gr.Examples(
         | 
| 184 | 
            -
                                examples=examples,
         | 
| 185 | 
            -
                                inputs=[prompt],
         | 
| 186 | 
            -
                                outputs=[result, seed],
         | 
| 187 | 
            -
                            )
         | 
| 188 | 
            -
             | 
| 189 | 
            -
                    with gr.Tab("Gallery"):
         | 
| 190 | 
            -
                        gallery_header = gr.Markdown("### Generated Images Gallery")
         | 
| 191 | 
            -
                        generated_gallery = gr.Gallery(
         | 
| 192 | 
            -
                            label="Generated Images",
         | 
| 193 | 
            -
                            columns=6,
         | 
| 194 | 
            -
                            show_label=False,
         | 
| 195 | 
            -
                            value=load_generated_images(),
         | 
| 196 | 
            -
                            elem_id="generated_gallery",
         | 
| 197 | 
            -
                            height="auto"
         | 
| 198 | 
            -
                        )
         | 
| 199 | 
            -
                        refresh_btn = gr.Button("π Refresh Gallery")
         | 
| 200 | 
            -
             | 
| 201 | 
            -
             | 
| 202 | 
            -
                # Event handlers
         | 
| 203 | 
            -
                def refresh_gallery():
         | 
| 204 | 
            -
                    return load_generated_images()
         | 
| 205 | 
            -
             | 
| 206 | 
            -
                refresh_btn.click(
         | 
| 207 | 
            -
                    fn=refresh_gallery,
         | 
| 208 | 
            -
                    inputs=None,
         | 
| 209 | 
            -
                    outputs=generated_gallery,
         | 
| 210 | 
            -
                )
         | 
| 211 |  | 
| 212 | 
            -
             | 
| 213 | 
            -
                     | 
| 214 | 
            -
                     | 
| 215 | 
            -
                     | 
| 216 | 
            -
             | 
| 217 | 
            -
             | 
| 218 | 
            -
                         | 
| 219 | 
            -
                         | 
| 220 | 
            -
             | 
| 221 | 
            -
             | 
| 222 | 
            -
             | 
| 223 | 
            -
             | 
| 224 | 
            -
                    ],
         | 
| 225 | 
            -
                    outputs=[result, seed, generated_gallery],
         | 
| 226 | 
            -
                )
         | 
| 227 |  | 
| 228 | 
            -
             | 
| 229 | 
            -
            demo.launch()
         | 
|  | |
| 1 | 
            +
            # app.py
         | 
| 2 | 
             
            import os
         | 
| 3 | 
            +
            import base64
         | 
| 4 | 
            +
            import streamlit as st
         | 
| 5 | 
            +
            from gradio_client import Client
         | 
| 6 | 
            +
            from dotenv import load_dotenv
         | 
| 7 | 
            +
            from pathlib import Path
         | 
| 8 | 
            +
            import json
         | 
| 9 | 
            +
            import hashlib
         | 
| 10 | 
            +
            import time
         | 
| 11 | 
            +
            from typing import Dict, Any
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            # Load environment variables
         | 
| 14 | 
            +
            load_dotenv()
         | 
| 15 | 
            +
            HF_TOKEN = os.getenv("HF_TOKEN")
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            # Cache directory setup
         | 
| 18 | 
            +
            CACHE_DIR = Path("./cache")
         | 
| 19 | 
            +
            CACHE_DIR.mkdir(exist_ok=True)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            # Cached example diagrams
         | 
| 22 | 
            +
            CACHED_EXAMPLES = {
         | 
| 23 | 
            +
                "literacy_mental": {
         | 
| 24 | 
            +
                    "title": "Literacy Mental Map",
         | 
| 25 | 
            +
                    "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, clear shapes, brain silhouette, text areas. must include the texts 
         | 
| 26 | 
            +
                    LITERACY/MENTAL
         | 
| 27 | 
            +
                    βββ PEACE [Dove Icon]
         | 
| 28 | 
            +
                    βββ HEALTH [Vitruvian Man ~60px]
         | 
| 29 | 
            +
                    βββ CONNECT [Brain-Mind Connection Icon]
         | 
| 30 | 
            +
                    βββ INTELLIGENCE
         | 
| 31 | 
            +
                    β   βββ EVERYTHING [Globe Icon ~50px]
         | 
| 32 | 
            +
                    βββ MEMORY
         | 
| 33 | 
            +
                        βββ READING [Book Icon ~40px]
         | 
| 34 | 
            +
                        βββ SPEED [Speedometer Icon]
         | 
| 35 | 
            +
                        βββ CREATIVITY
         | 
| 36 | 
            +
                            βββ INTELLIGENCE [Lightbulb + Infinity ~30px]""",
         | 
| 37 | 
            +
                    "width": 1024,
         | 
| 38 | 
            +
                    "height": 1024,
         | 
| 39 | 
            +
                    "seed": 1872187377,
         | 
| 40 | 
            +
                    "cache_path": "literacy_mental.png"
         | 
| 41 | 
            +
                }
         | 
| 42 | 
            +
            }
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            # Example diagrams for various use cases
         | 
| 45 | 
            +
            DIAGRAM_EXAMPLES = [
         | 
| 46 | 
            +
                {
         | 
| 47 | 
            +
                    "title": "Project Management Flow",
         | 
| 48 | 
            +
                    "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, clear shapes, project management flow.
         | 
| 49 | 
            +
                    PROJECT MANAGEMENT
         | 
| 50 | 
            +
                    βββ INITIATION [Rocket Icon]
         | 
| 51 | 
            +
                    βββ PLANNING [Calendar Icon]
         | 
| 52 | 
            +
                    βββ EXECUTION [Gear Icon]
         | 
| 53 | 
            +
                    βββ MONITORING
         | 
| 54 | 
            +
                    β   βββ CONTROL [Dashboard Icon]
         | 
| 55 | 
            +
                    βββ CLOSURE [Checkmark Icon]""",
         | 
| 56 | 
            +
                    "width": 1024,
         | 
| 57 | 
            +
                    "height": 1024
         | 
| 58 | 
            +
                },
         | 
| 59 | 
            +
                {
         | 
| 60 | 
            +
                    "title": "Digital Marketing Strategy",
         | 
| 61 | 
            +
                    "prompt": """A handrawn colorful mind map diagram, rugosity drawn lines, modern style, marketing concept.
         | 
| 62 | 
            +
                    DIGITAL MARKETING
         | 
| 63 | 
            +
                    βββ SEO [Magnifying Glass]
         | 
| 64 | 
            +
                    βββ SOCIAL MEDIA [Network Icon]
         | 
| 65 | 
            +
                    βββ CONTENT
         | 
| 66 | 
            +
                    β   βββ BLOG [Document Icon]
         | 
| 67 | 
            +
                    β   βββ VIDEO [Play Button]
         | 
| 68 | 
            +
                    βββ ANALYTICS [Graph Icon]""",
         | 
| 69 | 
            +
                    "width": 1024,
         | 
| 70 | 
            +
                    "height": 1024
         | 
| 71 | 
            +
                }
         | 
| 72 | 
            +
            ]
         | 
| 73 | 
            +
             | 
| 74 | 
            +
            # Add 15 more examples
         | 
| 75 | 
            +
            ADDITIONAL_EXAMPLES = [
         | 
| 76 | 
            +
                {
         | 
| 77 | 
            +
                    "title": "Health & Wellness",
         | 
| 78 | 
            +
                    "prompt": """A handrawn colorful mind map diagram, wellness-focused style, health aspects.
         | 
| 79 | 
            +
                    WELLNESS
         | 
| 80 | 
            +
                    βββ PHYSICAL [Dumbbell Icon]
         | 
| 81 | 
            +
                    βββ MENTAL [Brain Icon]
         | 
| 82 | 
            +
                    βββ NUTRITION [Apple Icon]
         | 
| 83 | 
            +
                    βββ SLEEP
         | 
| 84 | 
            +
                        βββ QUALITY [Star Icon]
         | 
| 85 | 
            +
                        βββ DURATION [Clock Icon]""",
         | 
| 86 | 
            +
                    "width": 1024,
         | 
| 87 | 
            +
                    "height": 1024
         | 
| 88 | 
            +
                }
         | 
| 89 | 
            +
                # ... (λλ¨Έμ§ μμ λ€)
         | 
| 90 | 
            +
            ]
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
             | 
| 94 | 
            +
             | 
| 95 | 
            +
             | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
             | 
| 99 | 
            +
             | 
| 100 | 
            +
             | 
| 101 | 
            +
             | 
| 102 | 
            +
            class DiagramCache:
         | 
| 103 | 
            +
                def __init__(self, cache_dir: Path):
         | 
| 104 | 
            +
                    self.cache_dir = cache_dir
         | 
| 105 | 
            +
                    self.cache_dir.mkdir(exist_ok=True)
         | 
| 106 | 
            +
                    self._load_cache()
         | 
| 107 | 
            +
                
         | 
| 108 | 
            +
                def _load_cache(self):
         | 
| 109 | 
            +
                    """Load existing cache entries"""
         | 
| 110 | 
            +
                    self.cache_index = {}
         | 
| 111 | 
            +
                    if (self.cache_dir / "cache_index.json").exists():
         | 
| 112 | 
            +
                        with open(self.cache_dir / "cache_index.json", "r") as f:
         | 
| 113 | 
            +
                            self.cache_index = json.load(f)
         | 
| 114 |  | 
| 115 | 
            +
                def _save_cache_index(self):
         | 
| 116 | 
            +
                    """Save cache index to disk"""
         | 
| 117 | 
            +
                    with open(self.cache_dir / "cache_index.json", "w") as f:
         | 
| 118 | 
            +
                        json.dump(self.cache_index, f)
         | 
| 119 |  | 
| 120 | 
            +
                def _get_cache_key(self, params: Dict[str, Any]) -> str:
         | 
| 121 | 
            +
                    """Generate cache key from parameters"""
         | 
| 122 | 
            +
                    param_str = json.dumps(params, sort_keys=True)
         | 
| 123 | 
            +
                    return hashlib.md5(param_str.encode()).hexdigest()
         | 
| 124 |  | 
| 125 | 
            +
                def get(self, params: Dict[str, Any]) -> Path:
         | 
| 126 | 
            +
                    """Get cached result if exists"""
         | 
| 127 | 
            +
                    cache_key = self._get_cache_key(params)
         | 
| 128 | 
            +
                    cache_info = self.cache_index.get(cache_key)
         | 
| 129 | 
            +
                    if cache_info:
         | 
| 130 | 
            +
                        cache_path = self.cache_dir / cache_info["filename"]
         | 
| 131 | 
            +
                        if cache_path.exists():
         | 
| 132 | 
            +
                            return cache_path
         | 
| 133 | 
            +
                    return None
         | 
| 134 | 
            +
                
         | 
| 135 | 
            +
                def put(self, params: Dict[str, Any], result_path: Path):
         | 
| 136 | 
            +
                    """Store result in cache"""
         | 
| 137 | 
            +
                    cache_key = self._get_cache_key(params)
         | 
| 138 | 
            +
                    filename = f"{cache_key}{result_path.suffix}"
         | 
| 139 | 
            +
                    cache_path = self.cache_dir / filename
         | 
| 140 | 
            +
                    
         | 
| 141 | 
            +
                    # Copy result to cache
         | 
| 142 | 
            +
                    with open(result_path, "rb") as src, open(cache_path, "wb") as dst:
         | 
| 143 | 
            +
                        dst.write(src.read())
         | 
| 144 | 
            +
                    
         | 
| 145 | 
            +
                    # Update index
         | 
| 146 | 
            +
                    self.cache_index[cache_key] = {
         | 
| 147 | 
            +
                        "filename": filename,
         | 
| 148 | 
            +
                        "timestamp": time.time(),
         | 
| 149 | 
            +
                        "params": params
         | 
| 150 | 
            +
                    }
         | 
| 151 | 
            +
                    self._save_cache_index()
         | 
| 152 | 
            +
             | 
| 153 |  | 
| 154 | 
            +
            # Initialize cache
         | 
| 155 | 
            +
            diagram_cache = DiagramCache(CACHE_DIR)
         | 
| 156 | 
            +
             | 
| 157 | 
            +
            @st.cache_data
         | 
| 158 | 
            +
            def generate_cached_example(example_id: str) -> str:
         | 
| 159 | 
            +
                """Generate and cache example diagram"""
         | 
| 160 | 
            +
                example = CACHED_EXAMPLES[example_id]
         | 
| 161 | 
            +
                client = Client("black-forest-labs/FLUX.1-schnell")
         | 
| 162 |  | 
| 163 | 
            +
                # Check cache first
         | 
| 164 | 
            +
                cache_path = diagram_cache.get(example)
         | 
| 165 | 
            +
                if cache_path:
         | 
| 166 | 
            +
                    with open(cache_path, "rb") as f:
         | 
| 167 | 
            +
                        return base64.b64encode(f.read()).decode()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 168 |  | 
| 169 | 
            +
                # Generate new image
         | 
| 170 | 
            +
                result = client.predict(
         | 
| 171 | 
            +
                    prompt=example["prompt"],
         | 
| 172 | 
            +
                    seed=example["seed"],
         | 
| 173 | 
            +
                    randomize_seed=False,
         | 
| 174 | 
            +
                    width=example["width"],
         | 
| 175 | 
            +
                    height=example["height"],
         | 
| 176 | 
            +
                    num_inference_steps=4,
         | 
| 177 | 
            +
                    api_name="/infer"
         | 
| 178 | 
            +
                )
         | 
| 179 |  | 
| 180 | 
            +
                # Cache the result
         | 
| 181 | 
            +
                diagram_cache.put(example, Path(result))
         | 
| 182 |  | 
| 183 | 
            +
                with open(result, "rb") as f:
         | 
| 184 | 
            +
                    return base64.b64encode(f.read()).decode()
         | 
| 185 |  | 
| 186 | 
            +
            def generate_diagram(prompt: str, width: int, height: int, seed: int = None) -> str:
         | 
| 187 | 
            +
                """Generate a new diagram"""
         | 
| 188 | 
            +
                client = Client("black-forest-labs/FLUX.1-schnell")
         | 
| 189 | 
            +
                params = {
         | 
| 190 | 
            +
                    "prompt": prompt,
         | 
| 191 | 
            +
                    "seed": seed if seed else 1872187377,
         | 
| 192 | 
            +
                    "width": width,
         | 
| 193 | 
            +
                    "height": height
         | 
| 194 | 
            +
                }
         | 
| 195 | 
            +
                
         | 
| 196 | 
            +
                # Check cache first
         | 
| 197 | 
            +
                cache_path = diagram_cache.get(params)
         | 
| 198 | 
            +
                if cache_path:
         | 
| 199 | 
            +
                    with open(cache_path, "rb") as f:
         | 
| 200 | 
            +
                        return base64.b64encode(f.read()).decode()
         | 
| 201 | 
            +
                
         | 
| 202 | 
            +
                # Generate new image
         | 
| 203 | 
            +
                try:
         | 
| 204 | 
            +
                    result = client.predict(
         | 
| 205 | 
            +
                        prompt=prompt,
         | 
| 206 | 
            +
                        seed=params["seed"],
         | 
| 207 | 
            +
                        randomize_seed=False,
         | 
| 208 | 
            +
                        width=width,
         | 
| 209 | 
            +
                        height=height,
         | 
| 210 | 
            +
                        num_inference_steps=4,
         | 
| 211 | 
            +
                        api_name="/infer"
         | 
| 212 | 
            +
                    )
         | 
| 213 | 
            +
                    
         | 
| 214 | 
            +
                    # Cache the result
         | 
| 215 | 
            +
                    diagram_cache.put(params, Path(result))
         | 
| 216 | 
            +
                    
         | 
| 217 | 
            +
                    with open(result, "rb") as f:
         | 
| 218 | 
            +
                        return base64.b64encode(f.read()).decode()
         | 
| 219 | 
            +
                except Exception as e:
         | 
| 220 | 
            +
                    st.error(f"Error generating diagram: {str(e)}")
         | 
| 221 | 
            +
                    return None
         | 
| 222 |  | 
|  | |
| 223 |  | 
| 224 | 
            +
            def main():
         | 
| 225 | 
            +
                st.set_page_config(page_title="FLUX Diagram Generator", layout="wide")
         | 
| 226 | 
            +
                
         | 
| 227 | 
            +
                st.title("π¨ FLUX Diagram Generator")
         | 
| 228 | 
            +
                st.markdown("Generate beautiful hand-drawn style diagrams using FLUX AI")
         | 
| 229 |  | 
| 230 | 
            +
                # Sidebar for examples
         | 
| 231 | 
            +
                st.sidebar.title("π Example Templates")
         | 
| 232 | 
            +
                selected_example = st.sidebar.selectbox(
         | 
| 233 | 
            +
                    "Choose a template",
         | 
| 234 | 
            +
                    options=range(len(DIAGRAM_EXAMPLES)),
         | 
| 235 | 
            +
                    format_func=lambda x: DIAGRAM_EXAMPLES[x]["title"]
         | 
| 236 | 
            +
                )
         | 
| 237 |  | 
| 238 | 
            +
                # Main content area
         | 
| 239 | 
            +
                col1, col2 = st.columns([2, 1])
         | 
| 240 |  | 
| 241 | 
            +
                with col1:
         | 
| 242 | 
            +
                    # Input area
         | 
| 243 | 
            +
                    prompt = st.text_area(
         | 
| 244 | 
            +
                        "Diagram Prompt",
         | 
| 245 | 
            +
                        value=DIAGRAM_EXAMPLES[selected_example]["prompt"],
         | 
| 246 | 
            +
                        height=200
         | 
| 247 | 
            +
                    )
         | 
| 248 |  | 
| 249 | 
            +
                    # Configuration
         | 
| 250 | 
            +
                    with st.expander("Advanced Configuration"):
         | 
| 251 | 
            +
                        width = st.number_input("Width", min_value=512, max_value=2048, value=1024, step=128)
         | 
| 252 | 
            +
                        height = st.number_input("Height", min_value=512, max_value=2048, value=1024, step=128)
         | 
| 253 | 
            +
                        seed = st.number_input("Seed (optional)", value=None, step=1)
         | 
| 254 |  | 
| 255 | 
            +
                    if st.button("π¨ Generate Diagram"):
         | 
| 256 | 
            +
                        with st.spinner("Generating your diagram..."):
         | 
| 257 | 
            +
                            result = generate_diagram(prompt, width, height, seed)
         | 
| 258 | 
            +
                            if result:
         | 
| 259 | 
            +
                                st.image(result, caption="Generated Diagram", use_column_width=True)
         | 
| 260 |  | 
| 261 | 
            +
                with col2:
         | 
| 262 | 
            +
                    st.subheader("Tips for Better Results")
         | 
| 263 | 
            +
                    st.markdown("""
         | 
| 264 | 
            +
                    - Use clear hierarchical structures
         | 
| 265 | 
            +
                    - Include icon descriptions in brackets
         | 
| 266 | 
            +
                    - Keep text concise and meaningful
         | 
| 267 | 
            +
                    - Use consistent formatting
         | 
| 268 | 
            +
                    """)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 269 |  | 
| 270 | 
            +
                    st.subheader("Template Structure")
         | 
| 271 | 
            +
                    st.code("""
         | 
| 272 | 
            +
                    MAIN TOPIC
         | 
| 273 | 
            +
                    βββ SUBTOPIC 1 [Icon]
         | 
| 274 | 
            +
                    βββ SUBTOPIC 2 [Icon]
         | 
| 275 | 
            +
                    βββ SUBTOPIC 3
         | 
| 276 | 
            +
                        βββ DETAIL 1 [Icon]
         | 
| 277 | 
            +
                        βββ DETAIL 2 [Icon]
         | 
| 278 | 
            +
                    """)
         | 
| 279 | 
            +
             | 
| 280 | 
            +
            if __name__ == "__main__":
         | 
| 281 | 
            +
                main()
         | 
|  | |
|  | |
|  | |
| 282 |  | 
| 283 | 
            +
                
         | 
|  | 
 
			
