		Fix issues about ZeroGPU and examples (#2)
- Delete .DS_Store and __pycache__ (c5e30ab8e6a0071c198b2b6a7cfe16b45de6c673)
- Add .gitignore (ecda16070478ee8ba8913bba520fc1d9fb4c80c1)
- Apply formatter to app.py and requirements.txt (33cb6e3e0847e54a8fce301fb39b6340beb34b1e)
- Clean up (5719c29c84ab2296306a92e39b5cc38f57a5bdb9)
- Remove unused import (c37294a743eb4be148f3d8d65c30d423efefc24c)
- Use huggingface_hub to download models (634953839cbc8dd333a76dd302f0c241b9c2f491)
- format (1109e54a30722d3a0bec94c8a4e98941b478555f)
- Change how gr.State is used (49f5f360f59a322ca015e5189a43a2b665a8a112)
- Add error handling (de8155f9965bbe0a4e2d70c508ca29a01d802a6d)
- Add error handling (0e7235bc177bb3a74a4f22a2b2a90bc6dbcfb781)
- Add error handling (5336236acbd5cf22c374422ddbddf6458842781d)
- Process examples when loaded (661a9c1678c0d5bc9e0f952e541d4b83951d1a2e)
Co-authored-by: hysts <[email protected]>
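
Three sub-commits above add error handling; their new code is mostly lost in this capture. As a hedged sketch of the usual Gradio pattern (names are illustrative, not taken from this Space), user-fixable problems are raised as gr.Error so they surface in the UI instead of crashing the Space:

```python
import gradio as gr

def run(first_frame_path, tracking_points):
    # Illustrative guards: fail fast with a UI-visible message rather
    # than a runtime error deep inside the pipeline.
    if first_frame_path is None:
        raise gr.Error("Please upload an image first.")
    if not tracking_points or not tracking_points[0]:
        raise gr.Error("Please add at least one drag trajectory.")
    # ... run the actual generation here ...
```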
Files changed:
- .DS_Store +0 -0
- .gitignore +162 -0
- __asset__/.DS_Store +0 -0
- __asset__/images/.DS_Store +0 -0
- __asset__/images/camera/.DS_Store +0 -0
- __asset__/images/object/.DS_Store +0 -0
- __asset__/trajs/.DS_Store +0 -0
- __asset__/trajs/camera/.DS_Store +0 -0
- __asset__/trajs/object/.DS_Store +0 -0
- app.py +445 -343
- configs/.DS_Store +0 -0
- models/.DS_Store +0 -0
- modules/__pycache__/attention.cpython-310.pyc +0 -0
- modules/__pycache__/flow_controlnet.cpython-310.pyc +0 -0
- modules/__pycache__/image_controlnet.cpython-310.pyc +0 -0
- modules/__pycache__/motion_module.cpython-310.pyc +0 -0
- modules/__pycache__/resnet.cpython-310.pyc +0 -0
- modules/__pycache__/unet.cpython-310.pyc +0 -0
- modules/__pycache__/unet_blocks.cpython-310.pyc +0 -0
- peft/__pycache__/__init__.cpython-310.pyc +0 -0
- peft/__pycache__/auto.cpython-310.pyc +0 -0
- peft/__pycache__/config.cpython-310.pyc +0 -0
- peft/__pycache__/import_utils.cpython-310.pyc +0 -0
- peft/__pycache__/mapping.cpython-310.pyc +0 -0
- peft/__pycache__/mixed_model.cpython-310.pyc +0 -0
- peft/__pycache__/peft_model.cpython-310.pyc +0 -0
- peft/tuners/__pycache__/__init__.cpython-310.pyc +0 -0
- peft/tuners/__pycache__/lycoris_utils.cpython-310.pyc +0 -0
- peft/tuners/__pycache__/tuners_utils.cpython-310.pyc +0 -0
- peft/tuners/adalora/__pycache__/__init__.cpython-310.pyc +0 -0
- peft/tuners/adalora/__pycache__/bnb.cpython-310.pyc +0 -0
- peft/tuners/adalora/__pycache__/config.cpython-310.pyc +0 -0
- peft/tuners/adalora/__pycache__/gptq.cpython-310.pyc +0 -0
- peft/tuners/adalora/__pycache__/layer.cpython-310.pyc +0 -0
- peft/tuners/adalora/__pycache__/model.cpython-310.pyc +0 -0
- peft/tuners/adaption_prompt/__pycache__/__init__.cpython-310.pyc +0 -0
- peft/tuners/adaption_prompt/__pycache__/config.cpython-310.pyc +0 -0
- peft/tuners/adaption_prompt/__pycache__/layer.cpython-310.pyc +0 -0
- peft/tuners/adaption_prompt/__pycache__/model.cpython-310.pyc +0 -0
- peft/tuners/adaption_prompt/__pycache__/utils.cpython-310.pyc +0 -0
- peft/tuners/boft/__pycache__/__init__.cpython-310.pyc +0 -0
- peft/tuners/boft/__pycache__/config.cpython-310.pyc +0 -0
- peft/tuners/boft/__pycache__/layer.cpython-310.pyc +0 -0
- peft/tuners/boft/__pycache__/model.cpython-310.pyc +0 -0
- peft/tuners/boft/fbd/__pycache__/__init__.cpython-310.pyc +0 -0
- peft/tuners/ia3/__pycache__/__init__.cpython-310.pyc +0 -0
- peft/tuners/ia3/__pycache__/bnb.cpython-310.pyc +0 -0
- peft/tuners/ia3/__pycache__/config.cpython-310.pyc +0 -0
- peft/tuners/ia3/__pycache__/layer.cpython-310.pyc +0 -0
- peft/tuners/ia3/__pycache__/model.cpython-310.pyc +0 -0
Binary file (8.2 kB)

.gitignore
@@ -0,0 +1,162 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
Binary file (6.15 kB)
Binary file (6.15 kB)
Binary file (6.15 kB)
Binary file (6.15 kB)
Binary file (6.15 kB)
Binary file (6.15 kB)
Binary file (6.15 kB)
app.py
@@ -1,35 +1,35 @@
 import os
-import …
-
-
-print("Installing correct gradio version...")
-os.system("pip uninstall -y gradio")
-os.system("pip install gradio==4.38.1")
-print("Installing Finished!")
-
 import gradio as gr
 import numpy as np
-import …
-import uuid
 import torch
 import torchvision
-import …
-import …
-
-from PIL import Image
 from omegaconf import OmegaConf
-from …
-from torchvision import transforms
 from transformers import CLIPTextModel, CLIPTokenizer
-from diffusers import AutoencoderKL, DDIMScheduler
-from pipelines.pipeline_imagecoductor import ImageConductorPipeline
 from modules.unet import UNet3DConditionFlowModel
-from …
-from utils.…
 from utils.lora_utils import add_LoRA_to_controlnet
-from utils.…
 
 #### Description ####
 title = r"""<h1 align="center">CustomNet: Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models</h1>"""
@@ -41,7 +41,7 @@ head = r"""
                            <a href='https://liyaowei-stu.github.io/project/ImageConductor/'><img src='https://img.shields.io/badge/Project_Page-ImgaeConductor-green' alt='Project Page'></a>
                            <a href='https://arxiv.org/pdf/2406.15339'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
                            <a href='https://github.com/liyaowei-stu/ImageConductor'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
-…
 
                        </div>
                        </br>
@@ -49,7 +49,6 @@ head = r"""
 """
 
 
-
 descriptions = r"""
 Official Gradio Demo for <a href='https://github.com/liyaowei-stu/ImageConductor'><b>Image Conductor: Precision Control for Interactive Video Synthesis</b></a>.<br>
 🧙Image Conductor enables precise, fine-grained control for generating motion-controllable videos from images, advancing the practical application of interactive video synthesis.<br>
@@ -66,7 +65,7 @@ instructions = r"""
            """
 
 citation = r"""
-If Image Conductor is helpful, please help to ⭐ the <a href='https://github.com/liyaowei-stu/ImageConductor' target='_blank'>Github Repo</a>. Thanks!…
 [](https://github.com/liyaowei-stu/ImageConductor)
 ---
 
@@ -75,7 +74,7 @@ If Image Conductor is helpful, please help to ⭐ the <a href='https://github.co…
 If our work is useful for your research, please consider citing:
 ```bibtex
 @misc{li2024imageconductor,
-    title={Image Conductor: Precision Control for Interactive Video Synthesis},
     author={Li, Yaowei and Wang, Xintao and Zhang, Zhaoyang and Wang, Zhouxia and Yuan, Ziyang and Xie, Liangbin and Zou, Yuexian and Shan, Ying},
     year={2024},
     eprint={2406.15339},
@@ -90,46 +89,19 @@ If you have any questions, please feel free to reach me out at <b>[email protected]…
 
 # """
 
-
-
-
-if not os.path.exists("models/flow_controlnet.ckpt"):
-    os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/flow_controlnet.ckpt?download=true -P models/')
-    os.system(f'mv models/flow_controlnet.ckpt?download=true models/flow_controlnet.ckpt')
-    print("flow_controlnet Download!", )
-
-if not os.path.exists("models/image_controlnet.ckpt"):
-    os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/image_controlnet.ckpt?download=true -P models/')
-    os.system(f'mv models/image_controlnet.ckpt?download=true models/image_controlnet.ckpt')
-    print("image_controlnet Download!", )
 
-…
-…
-    os.system(f'mv models/unet.ckpt?download=true models/unet.ckpt')
-    print("unet Download!", )
 
-…
 if not os.path.exists("models/sd1-5/config.json"):
-    os.…
-    os.system(f'mv models/sd1-5/config.json?download=true  models/sd1-5/config.json')
-    print("config Download!", )
-
-
 if not os.path.exists("models/sd1-5/unet.ckpt"):
-    os.…
-…
-# os.system(f'wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/diffusion_pytorch_model.bin?download=true -P models/sd1-5/')
-
-if not os.path.exists("models/personalized/helloobjects_V12c.safetensors"):
-    os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/helloobjects_V12c.safetensors?download=true -P models/personalized')
-    os.system(f'mv models/personalized/helloobjects_V12c.safetensors?download=true models/personalized/helloobjects_V12c.safetensors')
-    print("helloobjects_V12c Download!", )
-
-
-if not os.path.exists("models/personalized/TUSUN.safetensors"):
-    os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/TUSUN.safetensors?download=true -P models/personalized')
-    os.system(f'mv models/personalized/TUSUN.safetensors?download=true models/personalized/TUSUN.safetensors')
-    print("TUSUN Download!", )
 
 # mv1 = os.system(f'mv /usr/local/lib/python3.10/site-packages/gradio/helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers_bkp.py')
 # mv2 = os.system(f'mv helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers.py')
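
The wget/mv pairs removed above correspond to the "Use huggingface_hub to download models" sub-commit. The replacement code is not visible in this capture; a minimal sketch of the pattern, assuming the same TencentARC/ImageConductor repo (the local_dir choice is an assumption; the personalized weights lived under models/personalized/ in the old code):

```python
from huggingface_hub import hf_hub_download

# hf_hub_download fetches into the local cache and returns the path,
# so there is no "?download=true" suffix to clean up with mv.
for filename in [
    "flow_controlnet.ckpt",
    "image_controlnet.ckpt",
    "unet.ckpt",
    "helloobjects_V12c.safetensors",
    "TUSUN.safetensors",
]:
    hf_hub_download(
        repo_id="TencentARC/ImageConductor",
        filename=filename,
        local_dir="models",  # assumption: mirrors the old models/ layout
    )
```

Besides being quieter than shelling out, this also skips files already present in the cache, which matters on Spaces restarts.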
@@ -145,128 +117,135 @@ if not os.path.exists("models/personalized/TUSUN.safetensors"):
 # - - - - - examples  - - - - -  #
 
 image_examples = [
-    […
-    […
-    […
 ]
 
 
 POINTS = {
-    …
 }
 
 IMAGE_PATH = {
-    …
 }
 
 
 DREAM_BOOTH = {
-    …
 }
 
 LORA = {
-    …
 }
 
 LORA_ALPHA = {
-    …
 }
 
 NPROMPT = {
-    "HelloObject": …
 }
 
 output_dir = "outputs"
 ensure_dirname(output_dir)
 
 def points_to_flows(track_points, model_length, height, width):
     input_drag = np.zeros((model_length - 1, height, width, 2))
     for splited_track in track_points:
-        if len(splited_track) == 1:…
             displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
             splited_track = tuple([splited_track[0], displacement_point])
         # interpolate the track
         splited_track = interpolate_trajectory(splited_track, model_length)
         splited_track = splited_track[:model_length]
         if len(splited_track) < model_length:
-            splited_track = splited_track + [splited_track[-1]] * (model_length -len(splited_track))
         for i in range(model_length - 1):
             start_point = splited_track[i]
-            end_point = splited_track[i+1]
             input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]
             input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]
     return input_drag
 
 
 class ImageConductor:
-    def __init__(…
         self.device = device
-        tokenizer…
-        text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").to(…
         inference_config = OmegaConf.load("configs/inference/inference.yaml")
-        unet = UNet3DConditionFlowModel.from_pretrained_2d(…
 
         self.vae = vae
@@ -287,15 +266,14 @@ class ImageConductor:
 
         self.pipeline = ImageConductorPipeline(
             unet=unet,
-            vae=vae,…
-            tokenizer=tokenizer,…
-            text_encoder=text_encoder,…
             scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
             image_controlnet=image_controlnet,
             flow_controlnet=flow_controlnet,
         ).to(device)
 
-
         self.height = height
         self.width = width
         # _, model_step, _ = split_filename(model_path)
@@ -307,40 +285,51 @@ class ImageConductor:
         self.blur_kernel = blur_kernel
 
     @spaces.GPU(duration=180)
-    def run(…
         print("Run!")
-            points = json.load(open(POINTS[examples_type]))
-            tracking_points.value.extend(points)
-            print("example first_frame_path", first_frame_path)
-            print("example tracking_points", tracking_points.value)
-
-        original_width, original_height=384, 256
-        if isinstance(tracking_points, list):
-            input_all_points = tracking_points
-        else:
-            input_all_points = tracking_points.value
-
         print("input_all_points", input_all_points)
-        resized_all_points = […
 
         dir, base, ext = split_filename(first_frame_path)
-        id = base.split(…
 
-        ## image condition…
-        image_transforms = transforms.Compose(…
             transforms.RandomResizedCrop(
-                (self.height, self.width), (1.0, 1.0),
-                ratio=(self.width/self.height, self.width/self.height)
             ),
             transforms.ToTensor(),
-        ]…
 
         image_paths = [first_frame_path]
         controlnet_images = [(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]
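
The @spaces.GPU decorator above is the ZeroGPU hook named in the PR title: on a ZeroGPU Space, a GPU is attached only while a decorated function runs, so all CUDA work has to happen inside it. A minimal sketch of the mechanism (the model and function below are illustrative stand-ins, not this Space's code):

```python
import spaces
import torch

model = torch.nn.Linear(4, 4)  # stand-in for the real pipeline, built on CPU

@spaces.GPU(duration=180)  # request up to 180 s of GPU time per call
def generate(x: torch.Tensor) -> torch.Tensor:
    # A CUDA device is only guaranteed inside the decorated function on
    # ZeroGPU, so device moves and inference belong here, not at import time.
    model.to("cuda")
    return model(x.to("cuda"))
```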
@@ -349,205 +338,296 @@ class ImageConductor:
         num_controlnet_images = controlnet_images.shape[2]
         controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w")
         self.vae.to(device)
-        controlnet_images = self.vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * 0.18215
         controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images)
 
         # flow condition
         controlnet_flows = points_to_flows(resized_all_points, self.model_length, self.height, self.width)
-        for i in range(0, self.model_length-1):
             controlnet_flows[i] = cv2.filter2D(controlnet_flows[i], -1, self.blur_kernel)
-        controlnet_flows = np.concatenate(…
         os.makedirs(os.path.join(output_dir, "control_flows"), exist_ok=True)
-        trajs_video = vis_flow_to_video(controlnet_flows, num_frames=self.model_length)…
-        torchvision.io.write_video(…
 
-        dreambooth_model_path = DREAM_BOOTH.get(personalized, …
-        lora_model_path = LORA.get(personalized, …
         lora_alpha = LORA_ALPHA.get(personalized, 0.6)
         self.pipeline = load_weights(
             self.pipeline,
-            dreambooth_model_path…
-            lora_model_path…
-            lora_alpha…
         ).to(device)
-        if NPROMPT.get(personalized, …
-            negative_prompt =…
         if randomize_seed:
             random_seed = torch.seed()
         else:
             seed = int(seed)
             random_seed = seed
         torch.manual_seed(random_seed)
-        torch.cuda.manual_seed_all(random_seed)…
         print(f"current seed: {torch.initial_seed()}")
         sample = self.pipeline(
-        outputs_path = os.path.join(output_dir, f…
-        vis_video = (rearrange(sample[0], …
-        torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec=…
         # outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
         # save_videos_grid(sample[0][None], outputs_path)
         print("Done!")
-        return …
 
 
 def reset_states(first_frame_path, tracking_points):
-    first_frame_path = …
-    tracking_points = …
-    return {input_image:None, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
 
 
 def preprocess_image(image, tracking_points):
     image_pil = image2pil(image.name)
     raw_w, raw_h = image_pil.size
-    resize_ratio = max(384/raw_w, 256/raw_h)
     image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
-    image_pil = transforms.CenterCrop((256, 384))(image_pil.convert(…
     id = str(uuid.uuid4())[:4]
     first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
     image_pil.save(first_frame_path, quality=95)
-    tracking_points =…
-    return {…
 
-    if drag_mode=='object':
         color = (255, 0, 0, 255)
-    elif drag_mode==…
         color = (0, 0, 255, 255)
 
-        print(f"You selected {evt.value} at {evt.index} from {evt.target}")
-        tracking_points.value[-1].append(evt.index)
-        print(tracking_points.value)
-        tracking_points_values =  tracking_points.value
-    else:
-        try:
-            tracking_points[-1].append(evt.index)
-        except Exception as e:
-            tracking_points.append([])
-            tracking_points[-1].append(evt.index)
-            print(f"Solved Error: {e}")
-
-        tracking_points_values = tracking_points
-
-    transparent_background = Image.open(first_frame_path).convert('RGBA')
     w, h = transparent_background.size
     transparent_layer = np.zeros((h, w, 4))
-    for track in …
         if len(track) > 1:
-            for i in range(len(track)-1):
                 start_point = track[i]
-                end_point = track[i+1]
                 vx = end_point[0] - start_point[0]
                 vy = end_point[1] - start_point[1]
                 arrow_length = np.sqrt(vx**2 + vy**2)
-                if i == len(track)-2:
-                    cv2.arrowedLine(…
                 else:
-                    cv2.line(…
         else:
             cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
 
     transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
     trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
 
 
 def add_drag(tracking_points):
-    if not …
-        # print("before", tracking_points.value)
-        tracking_points.value.append([])
-        # print(tracking_points.value)
-    else:
         tracking_points.append([])
     return {tracking_points_var: tracking_points}
 
 def delete_last_drag(tracking_points, first_frame_path, drag_mode):
-    if drag_mode==…
         color = (255, 0, 0, 255)
-    elif drag_mode==…
         color = (0, 0, 255, 255)
-    tracking_points…
     w, h = transparent_background.size
     transparent_layer = np.zeros((h, w, 4))
-    for track in tracking_points…
         if len(track) > 1:
-            for i in range(len(track)-1):
                 start_point = track[i]
-                end_point = track[i+1]
                 vx = end_point[0] - start_point[0]
                 vy = end_point[1] - start_point[1]
                 arrow_length = np.sqrt(vx**2 + vy**2)
-                if i == len(track)-2:
-                    cv2.arrowedLine(…
                 else:
-                    cv2.line(…
         else:
             cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
 
     transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
     trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
     return {tracking_points_var: tracking_points, input_image: trajectory_map}
 
 def delete_last_step(tracking_points, first_frame_path, drag_mode):
-    if drag_mode==…
         color = (255, 0, 0, 255)
-    elif drag_mode==…
         color = (0, 0, 255, 255)
-    tracking_points…
     w, h = transparent_background.size
     transparent_layer = np.zeros((h, w, 4))
-    for track in tracking_points…
         if len(track) > 1:
-            for i in range(len(track)-1):
                 start_point = track[i]
-                end_point = track[i+1]
                 vx = end_point[0] - start_point[0]
                 vy = end_point[1] - start_point[1]
                 arrow_length = np.sqrt(vx**2 + vy**2)
-                if i == len(track)-2:
-                    cv2.arrowedLine(…
                 else:
-                    cv2.line(…
         else:
-            cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
 
     transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
     trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
     return {tracking_points_var: tracking_points, input_image: trajectory_map}
 
 
 with block:
     with gr.Row():
         with gr.Column():
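
The branches removed above special-cased gr.State objects and mutated tracking_points.value in place; the "Change how gr.State is used" sub-commit moves to the conventional pattern, where the handler receives the state's current value and returns the new one. A minimal sketch of that pattern (component names are illustrative):

```python
import gradio as gr

def add_drag(tracking_points: list) -> list:
    # The handler gets the plain value held by gr.State, never the State
    # object itself, so no isinstance/.value branching is needed.
    tracking_points.append([])
    return tracking_points  # returned value is stored back into the State

with gr.Blocks() as demo:
    tracking_points_var = gr.State([])  # per-session list of tracks
    add_drag_button = gr.Button("Add Drag")
    add_drag_button.click(add_drag, tracking_points_var, tracking_points_var)
```

Returning the value also keeps state per-session, whereas mutating a module-level State's default would leak tracks across users.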
@@ -557,66 +637,58 @@ with block:
 
     with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
         with gr.Row(equal_height=True):
-            gr.Markdown(instructions)…
-    # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    device = torch.device("cuda")
-    unet_path = 'models/unet.ckpt'
-    image_controlnet_path = 'models/image_controlnet.ckpt'
-    flow_controlnet_path = 'models/flow_controlnet.ckpt'
-    ImageConductor_net = ImageConductor(device=device,
-                                        unet_path=unet_path,
-                                        image_controlnet_path=image_controlnet_path,
-                                        flow_controlnet_path=flow_controlnet_path,
-                                        height=256,
-                                        width=384,
-                                        model_length=16
-                                        )
-    first_frame_path_var = gr.State(value=None)
     tracking_points_var = gr.State([])
 
     with gr.Row():
         with gr.Column(scale=1):
-            image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
             add_drag_button = gr.Button(value="Add Drag")
             reset_button = gr.Button(value="Reset")
             delete_last_drag_button = gr.Button(value="Delete last drag")
             delete_last_step_button = gr.Button(value="Delete last step")
 
         with gr.Column(scale=7):
             with gr.Row():
                 with gr.Column(scale=6):
-                    input_image = gr.Image(…
                 with gr.Column(scale=6):
-                    output_image = gr.Image(…
     with gr.Row():
         with gr.Column(scale=1):
-            prompt = gr.Textbox(…
             negative_prompt = gr.Text(
             run_button = gr.Button(value="Run")
 
         with gr.Accordion("More input params", open=False, elem_id="accordion1"):
             with gr.Group():
-                seed = gr.Textbox(
-                    label="Seed: ",  value=561793204,
-                )
                 randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
             with gr.Group():
                 with gr.Row():
                     guidance_scale = gr.Slider(
@@ -633,24 +705,15 @@ with block:
                         step=1,
                         value=25,
                     )
             with gr.Group():
-                personalized = gr.Dropdown(label="Personalized", choices=["", …
-                examples_type = gr.Textbox(label="Examples Type (Ignore) ",…
 
         with gr.Column(scale=7):
-            output_video = gr.Video(
-                                    label="Output Video",
-                                    width=384,
-                                    height=256)
-            # output_video = gr.Image(label="Output Video",
-            #                                 height=256,
-            #                                 width=384,)
-    with gr.Row():
 
     example = gr.Examples(
         label="Input Example",
         examples=image_examples,
| @@ -658,26 +721,65 @@ with block: | |
| 658 | 
             
                        examples_per_page=10,
         | 
| 659 | 
             
                        cache_examples=False,
         | 
| 660 | 
             
                    )
         | 
| 661 | 
            -
             | 
| 662 | 
            -
                    
         | 
| 663 | 
             
                with gr.Row():
         | 
| 664 | 
             
                    gr.Markdown(citation)
         | 
| 665 |  | 
| 666 | 
            -
                
         | 
| 667 | 
            -
             | 
|  | |
|  | |
|  | |
| 668 |  | 
| 669 | 
             
                add_drag_button.click(add_drag, tracking_points_var, tracking_points_var)
         | 
| 670 |  | 
| 671 | 
            -
                delete_last_drag_button.click( | 
| 672 | 
            -
             | 
| 673 | 
            -
             | 
| 674 | 
            -
             | 
| 675 | 
            -
                 | 
| 676 | 
            -
             | 
| 677 | 
            -
                 | 
| 678 | 
            -
             | 
| 679 | 
            -
             | 
| 680 | 
            -
             | 
| 681 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 682 |  | 
| 683 | 
             
            block.queue().launch()
         | 
|  | |
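The hunk above removes the preset `gr.State(value=None)` and the sprawling multi-line component constructors in favor of the compact definitions shown in the full file below. For orientation, here is a minimal sketch of the session-state pattern the reworked app relies on, with hypothetical component names, assuming current Gradio Blocks semantics:

```python
import gradio as gr

with gr.Blocks() as demo:
    points = gr.State([])  # per-session state; each visitor gets a fresh copy
    img = gr.Image(label="Canvas")

    def on_select(points, evt: gr.SelectData):
        # evt.index is the (x, y) pixel that was clicked
        points.append(evt.index)
        return points

    # the state is both an input and an output, so the handler is a pure
    # function of its arguments instead of mutating module-level globals
    img.select(on_select, points, points)
```

Threading state through `inputs`/`outputs` rather than globals is what keeps sessions isolated, which matters on a shared Space.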
+import json
 import os
+import uuid
 
+import cv2
 import gradio as gr
 import numpy as np
+import spaces
 import torch
 import torchvision
+from diffusers import AutoencoderKL, DDIMScheduler
+from einops import rearrange
+from huggingface_hub import hf_hub_download
 from omegaconf import OmegaConf
+from PIL import Image
+from torchvision import transforms
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from modules.unet import UNet3DConditionFlowModel
+from pipelines.pipeline_imagecoductor import ImageConductorPipeline
+from utils.gradio_utils import ensure_dirname, image2pil, split_filename, visualize_drag
 from utils.lora_utils import add_LoRA_to_controlnet
+from utils.utils import (
+    bivariate_Gaussian,
+    create_flow_controlnet,
+    create_image_controlnet,
+    interpolate_trajectory,
+    load_model,
+    load_weights,
+)
+from utils.visualizer import vis_flow_to_video
+
 #### Description ####
 title = r"""<h1 align="center">CustomNet: Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models</h1>"""
 
...
                         <a href='https://liyaowei-stu.github.io/project/ImageConductor/'><img src='https://img.shields.io/badge/Project_Page-ImgaeConductor-green' alt='Project Page'></a>
                         <a href='https://arxiv.org/pdf/2406.15339'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
                         <a href='https://github.com/liyaowei-stu/ImageConductor'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
+
 
                     </div>
                     </br>
...
 """
 
 
 descriptions = r"""
 Official Gradio Demo for <a href='https://github.com/liyaowei-stu/ImageConductor'><b>Image Conductor: Precision Control for Interactive Video Synthesis</b></a>.<br>
 🧙Image Conductor enables precise, fine-grained control for generating motion-controllable videos from images, advancing the practical application of interactive video synthesis.<br>
...
             """
 
 citation = r"""
+If Image Conductor is helpful, please help to ⭐ the <a href='https://github.com/liyaowei-stu/ImageConductor' target='_blank'>Github Repo</a>. Thanks!
 [](https://github.com/liyaowei-stu/ImageConductor)
 ---
 
...
 If our work is useful for your research, please consider citing:
 ```bibtex
 @misc{li2024imageconductor,
+    title={Image Conductor: Precision Control for Interactive Video Synthesis},
     author={Li, Yaowei and Wang, Xintao and Zhang, Zhaoyang and Wang, Zhouxia and Yuan, Ziyang and Xie, Liangbin and Zou, Yuexian and Shan, Ying},
     year={2024},
     eprint={2406.15339},
...
 # """
 
+flow_controlnet_path = hf_hub_download("TencentARC/ImageConductor", "flow_controlnet.ckpt")
+image_controlnet_path = hf_hub_download("TencentARC/ImageConductor", "image_controlnet.ckpt")
+unet_path = hf_hub_download("TencentARC/ImageConductor", "unet.ckpt")
 
+helloobjects_path = hf_hub_download("TencentARC/ImageConductor", "helloobjects_V12c.safetensors")
+tusun_path = hf_hub_download("TencentARC/ImageConductor", "TUSUN.safetensors")
 
+os.makedirs("models/sd1-5", exist_ok=True)
+sd15_config_path = hf_hub_download("runwayml/stable-diffusion-v1-5", "config.json", subfolder="unet")
 if not os.path.exists("models/sd1-5/config.json"):
+    os.symlink(sd15_config_path, "models/sd1-5/config.json")
 if not os.path.exists("models/sd1-5/unet.ckpt"):
+    os.symlink(unet_path, "models/sd1-5/unet.ckpt")
 
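`hf_hub_download` returns a path inside the local Hugging Face cache and skips the network when the file is already present, so the symlinks above give the custom UNet loader its expected `models/sd1-5` layout without copying multi-gigabyte checkpoints. A hedged, standalone sketch of the same idea (repo and filename taken from the diff above):

```python
import os

from huggingface_hub import hf_hub_download

# Resolves to a file inside the HF cache; repeated calls are effectively free.
ckpt = hf_hub_download("TencentARC/ImageConductor", "unet.ckpt")

# Link the cached file into the fixed directory layout the loader expects.
os.makedirs("models/sd1-5", exist_ok=True)
if not os.path.exists("models/sd1-5/unet.ckpt"):
    os.symlink(ckpt, "models/sd1-5/unet.ckpt")
```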
 
 # mv1 = os.system(f'mv /usr/local/lib/python3.10/site-packages/gradio/helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers_bkp.py')
 # mv2 = os.system(f'mv helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers.py')
...
 # - - - - - examples  - - - - -  #
 
 image_examples = [
+    [
+        "__asset__/images/object/turtle-1.jpg",
+        "a sea turtle gracefully swimming over a coral reef in the clear blue ocean.",
+        "object",
+        11318446767408804497,
+        "",
+        "turtle",
+        "__asset__/turtle.mp4",
+    ],
+    [
+        "__asset__/images/object/rose-1.jpg",
+        "a red rose engulfed in flames.",
+        "object",
+        6854275249656120509,
+        "",
+        "rose",
+        "__asset__/rose.mp4",
+    ],
+    [
+        "__asset__/images/object/jellyfish-1.jpg",
+        "intricate detailing,photorealism,hyperrealistic, glowing jellyfish mushroom, flying, starry sky, bokeh, golden ratio composition.",
+        "object",
+        17966188172968903484,
+        "HelloObject",
+        "jellyfish",
+        "__asset__/jellyfish.mp4",
+    ],
+    [
+        "__asset__/images/camera/lush-1.jpg",
+        "detailed craftsmanship, photorealism, hyperrealistic, roaring waterfall, misty spray, lush greenery, vibrant rainbow, golden ratio composition.",
+        "camera",
+        7970487946960948963,
+        "HelloObject",
+        "lush",
+        "__asset__/lush.mp4",
+    ],
+    [
+        "__asset__/images/camera/tusun-1.jpg",
+        "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing.",
+        "camera",
+        996953226890228361,
+        "TUSUN",
+        "tusun",
+        "__asset__/tusun.mp4",
+    ],
+    [
+        "__asset__/images/camera/painting-1.jpg",
+        "A oil painting.",
+        "camera",
+        16867854766769816385,
+        "",
+        "painting",
+        "__asset__/painting.mp4",
+    ],
 ]
 
 
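The `inputs=` wiring of the `gr.Examples` block is elided in this diff, so the field order below is an inference from the literal values themselves rather than from the source:

```python
# Hypothetical unpacking of one example row (field order inferred, not confirmed).
image_path, prompt, drag_mode, seed, personalized, examples_type, output_video = image_examples[0]
assert drag_mode in ("object", "camera")
```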
 POINTS = {
+    "turtle": "__asset__/trajs/object/turtle-1.json",
+    "rose": "__asset__/trajs/object/rose-1.json",
+    "jellyfish": "__asset__/trajs/object/jellyfish-1.json",
+    "lush": "__asset__/trajs/camera/lush-1.json",
+    "tusun": "__asset__/trajs/camera/tusun-1.json",
+    "painting": "__asset__/trajs/camera/painting-1.json",
 }
 
 IMAGE_PATH = {
+    "turtle": "__asset__/images/object/turtle-1.jpg",
+    "rose": "__asset__/images/object/rose-1.jpg",
+    "jellyfish": "__asset__/images/object/jellyfish-1.jpg",
+    "lush": "__asset__/images/camera/lush-1.jpg",
+    "tusun": "__asset__/images/camera/tusun-1.jpg",
+    "painting": "__asset__/images/camera/painting-1.jpg",
 }
 
 
 DREAM_BOOTH = {
+    "HelloObject": helloobjects_path,
 }
 
 LORA = {
+    "TUSUN": tusun_path,
 }
 
 LORA_ALPHA = {
+    "TUSUN": 0.6,
 }
 
 NPROMPT = {
+    "HelloObject": "FastNegativeV2,(bad-artist:1),(worst quality, low quality:1.4),(bad_prompt_version2:0.8),bad-hands-5,lowres,bad anatomy,bad hands,((text)),(watermark),error,missing fingers,extra digit,fewer digits,cropped,worst quality,low quality,normal quality,((username)),blurry,(extra limbs),bad-artist-anime,badhandv4,EasyNegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3,BadDream,(three hands:1.6),(three legs:1.2),(more than two hands:1.4),(more than two legs,:1.2)"
 }
 
 output_dir = "outputs"
 ensure_dirname(output_dir)
 
+
 def points_to_flows(track_points, model_length, height, width):
     input_drag = np.zeros((model_length - 1, height, width, 2))
     for splited_track in track_points:
+        if len(splited_track) == 1:  # stationary point
             displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
             splited_track = tuple([splited_track[0], displacement_point])
         # interpolate the track
         splited_track = interpolate_trajectory(splited_track, model_length)
         splited_track = splited_track[:model_length]
         if len(splited_track) < model_length:
+            splited_track = splited_track + [splited_track[-1]] * (model_length - len(splited_track))
         for i in range(model_length - 1):
             start_point = splited_track[i]
+            end_point = splited_track[i + 1]
             input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]
             input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]
     return input_drag
 
+
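A rough usage sketch of `points_to_flows`: each track is resampled to `model_length` points (assuming `interpolate_trajectory` behaves as its use above suggests), and each frame-to-frame step writes one `(dx, dy)` vector at the integer pixel of the step's start point, leaving the rest of the field zero:

```python
import numpy as np

# One drag from (10, 20) to (40, 50) over a 16-frame clip at 256x384.
flows = points_to_flows([[(10.0, 20.0), (40.0, 50.0)]], model_length=16, height=256, width=384)
print(flows.shape)       # (15, 256, 384, 2): one sparse (dx, dy) field per step
print(flows[0, 20, 10])  # displacement stored at the frame-0 start point (y=20, x=10)
```

The Gaussian blur applied later in `run` is what spreads these one-pixel vectors into a dense flow the flow controlnet can condition on.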
 class ImageConductor:
+    def __init__(
+        self, device, unet_path, image_controlnet_path, flow_controlnet_path, height, width, model_length, lora_rank=64
+    ):
         self.device = device
+        tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
+        text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").to(
+            device
+        )
+        vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(device)
         inference_config = OmegaConf.load("configs/inference/inference.yaml")
+        unet = UNet3DConditionFlowModel.from_pretrained_2d(
+            "models/sd1-5/", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)
+        )
 
         self.vae = vae
 
...
 
         self.pipeline = ImageConductorPipeline(
             unet=unet,
+            vae=vae,
+            tokenizer=tokenizer,
+            text_encoder=text_encoder,
             scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
             image_controlnet=image_controlnet,
             flow_controlnet=flow_controlnet,
         ).to(device)
 
         self.height = height
         self.width = width
         # _, model_step, _ = split_filename(model_path)
...
         self.blur_kernel = blur_kernel
 
     @spaces.GPU(duration=180)
+    def run(
+        self,
+        first_frame_path,
+        tracking_points,
+        prompt,
+        drag_mode,
+        negative_prompt,
+        seed,
+        randomize_seed,
+        guidance_scale,
+        num_inference_steps,
+        personalized,
+    ):
         print("Run!")
+
+        original_width, original_height = 384, 256
+        input_all_points = tracking_points
+
         print("input_all_points", input_all_points)
+        resized_all_points = [
+            tuple(
+                [
+                    tuple([float(e1[0] * self.width / original_width), float(e1[1] * self.height / original_height)])
+                    for e1 in e
+                ]
+            )
+            for e in input_all_points
+        ]
 
         dir, base, ext = split_filename(first_frame_path)
+        id = base.split("_")[-1]
+
+        visualized_drag, _ = visualize_drag(
+            first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length
+        )
 
+        ## image condition
+        image_transforms = transforms.Compose(
+            [
                 transforms.RandomResizedCrop(
+                    (self.height, self.width), (1.0, 1.0), ratio=(self.width / self.height, self.width / self.height)
                 ),
                 transforms.ToTensor(),
+            ]
+        )
 
         image_paths = [first_frame_path]
         controlnet_images = [(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]
...
         num_controlnet_images = controlnet_images.shape[2]
         controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w")
         self.vae.to(device)
+        controlnet_images = self.vae.encode(controlnet_images * 2.0 - 1.0).latent_dist.sample() * 0.18215
         controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images)
 
         # flow condition
         controlnet_flows = points_to_flows(resized_all_points, self.model_length, self.height, self.width)
+        for i in range(0, self.model_length - 1):
             controlnet_flows[i] = cv2.filter2D(controlnet_flows[i], -1, self.blur_kernel)
+        controlnet_flows = np.concatenate(
+            [np.zeros_like(controlnet_flows[0])[np.newaxis, ...], controlnet_flows], axis=0
+        )  # pad the first frame with zero flow
         os.makedirs(os.path.join(output_dir, "control_flows"), exist_ok=True)
+        trajs_video = vis_flow_to_video(controlnet_flows, num_frames=self.model_length)  # T-1 x H x W x 3
+        torchvision.io.write_video(
+            f"{output_dir}/control_flows/sample-{id}-train_flow.mp4",
+            trajs_video,
+            fps=8,
+            video_codec="h264",
+            options={"crf": "10"},
+        )
+        controlnet_flows = torch.from_numpy(controlnet_flows)[None][:, : self.model_length, ...]
+        controlnet_flows = rearrange(controlnet_flows, "b f h w c-> b c f h w").float().to(device)
 
+        dreambooth_model_path = DREAM_BOOTH.get(personalized, "")
+        lora_model_path = LORA.get(personalized, "")
         lora_alpha = LORA_ALPHA.get(personalized, 0.6)
         self.pipeline = load_weights(
             self.pipeline,
+            dreambooth_model_path=dreambooth_model_path,
+            lora_model_path=lora_model_path,
+            lora_alpha=lora_alpha,
         ).to(device)
+
+        if NPROMPT.get(personalized, "") != "":
+            negative_prompt = NPROMPT.get(personalized)
+
         if randomize_seed:
             random_seed = torch.seed()
         else:
             seed = int(seed)
             random_seed = seed
         torch.manual_seed(random_seed)
+        torch.cuda.manual_seed_all(random_seed)
         print(f"current seed: {torch.initial_seed()}")
         sample = self.pipeline(
+            prompt,
+            negative_prompt=negative_prompt,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            width=self.width,
+            height=self.height,
+            video_length=self.model_length,
+            controlnet_images=controlnet_images,  # 1 4 1 32 48
+            controlnet_image_index=[0],
+            controlnet_flows=controlnet_flows,  # [1, 2, 16, 256, 384]
+            control_mode=drag_mode,
+            eval_mode=True,
+        ).videos
+
+        outputs_path = os.path.join(output_dir, f"output_{i}_{id}.mp4")
+        vis_video = (rearrange(sample[0], "c t h w -> t h w c") * 255.0).clip(0, 255)
+        torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec="h264", options={"crf": "10"})
+
         # outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
         # save_videos_grid(sample[0][None], outputs_path)
         print("Done!")
+        return visualized_drag, outputs_path
 
 
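On ZeroGPU Spaces, a GPU is attached only for calls decorated with `spaces.GPU`; module-level model construction runs without one. A minimal sketch of the pattern `run` uses above (the body here is a placeholder):

```python
import spaces
import torch

@spaces.GPU(duration=180)  # upper bound, in seconds, on how long this call may hold a GPU
def generate(prompt: str) -> torch.Tensor:
    # CUDA is only guaranteed to be usable inside the decorated call
    return torch.zeros(1)
```

This is consistent with `run` re-issuing `self.vae.to(device)` right before encoding rather than assuming device placement done at import time still holds.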
 def reset_states(first_frame_path, tracking_points):
+    first_frame_path = None
+    tracking_points = []
+    return {input_image: None, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
 
 
 def preprocess_image(image, tracking_points):
     image_pil = image2pil(image.name)
     raw_w, raw_h = image_pil.size
+    resize_ratio = max(384 / raw_w, 256 / raw_h)
     image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
+    image_pil = transforms.CenterCrop((256, 384))(image_pil.convert("RGB"))
     id = str(uuid.uuid4())[:4]
     first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
     image_pil.save(first_frame_path, quality=95)
+    tracking_points = []
+    return {
+        input_image: first_frame_path,
+        first_frame_path_var: first_frame_path,
+        tracking_points_var: tracking_points,
+        personalized: "",
+    }
+
+
+def add_tracking_points(
+    tracking_points, first_frame_path, drag_mode, evt: gr.SelectData
+):  # SelectData is a subclass of EventData
+    if drag_mode == "object":
+        color = (255, 0, 0, 255)
+    elif drag_mode == "camera":
+        color = (0, 0, 255, 255)
+
+    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+    if not tracking_points:
+        tracking_points = [[]]
+    tracking_points[-1].append(evt.index)
+
+    transparent_background = Image.open(first_frame_path).convert("RGBA")
+    w, h = transparent_background.size
+    transparent_layer = np.zeros((h, w, 4))
+
+    for track in tracking_points:
+        if len(track) > 1:
+            for i in range(len(track) - 1):
+                start_point = track[i]
+                end_point = track[i + 1]
+                vx = end_point[0] - start_point[0]
+                vy = end_point[1] - start_point[1]
+                arrow_length = np.sqrt(vx**2 + vy**2)
+                if i == len(track) - 2:
+                    cv2.arrowedLine(
+                        transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length
+                    )
+                else:
+                    cv2.line(
+                        transparent_layer,
+                        tuple(start_point),
+                        tuple(end_point),
+                        color,
+                        2,
+                    )
+        else:
+            cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
+
+    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+    trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+    return {tracking_points_var: tracking_points, input_image: trajectory_map}
+
 
+def preprocess_example_image(image_path, tracking_points, drag_mode):
+    image_pil = image2pil(image_path)
+    raw_w, raw_h = image_pil.size
+    resize_ratio = max(384 / raw_w, 256 / raw_h)
+    image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
+    image_pil = transforms.CenterCrop((256, 384))(image_pil.convert("RGB"))
+    id = str(uuid.uuid4())[:4]
+    first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
+    image_pil.save(first_frame_path, quality=95)
 
+    if drag_mode == "object":
         color = (255, 0, 0, 255)
+    elif drag_mode == "camera":
         color = (0, 0, 255, 255)
 
+    transparent_background = Image.open(first_frame_path).convert("RGBA")
     w, h = transparent_background.size
     transparent_layer = np.zeros((h, w, 4))
+
+    for track in tracking_points:
         if len(track) > 1:
+            for i in range(len(track) - 1):
                 start_point = track[i]
+                end_point = track[i + 1]
                 vx = end_point[0] - start_point[0]
                 vy = end_point[1] - start_point[1]
                 arrow_length = np.sqrt(vx**2 + vy**2)
+                if i == len(track) - 2:
+                    cv2.arrowedLine(
+                        transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length
+                    )
                 else:
+                    cv2.line(
+                        transparent_layer,
+                        tuple(start_point),
+                        tuple(end_point),
+                        color,
+                        2,
+                    )
         else:
             cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
 
     transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
     trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+    return trajectory_map, first_frame_path
 
 
 def add_drag(tracking_points):
+    if not tracking_points or tracking_points[-1]:
         tracking_points.append([])
     return {tracking_points_var: tracking_points}
+
 
 def delete_last_drag(tracking_points, first_frame_path, drag_mode):
+    if drag_mode == "object":
         color = (255, 0, 0, 255)
+    elif drag_mode == "camera":
         color = (0, 0, 255, 255)
+    if tracking_points:
+        tracking_points.pop()
+    transparent_background = Image.open(first_frame_path).convert("RGBA")
     w, h = transparent_background.size
     transparent_layer = np.zeros((h, w, 4))
+    for track in tracking_points:
         if len(track) > 1:
+            for i in range(len(track) - 1):
                 start_point = track[i]
+                end_point = track[i + 1]
                 vx = end_point[0] - start_point[0]
                 vy = end_point[1] - start_point[1]
                 arrow_length = np.sqrt(vx**2 + vy**2)
+                if i == len(track) - 2:
+                    cv2.arrowedLine(
+                        transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length
+                    )
                 else:
+                    cv2.line(
+                        transparent_layer,
+                        tuple(start_point),
+                        tuple(end_point),
+                        color,
+                        2,
+                    )
         else:
             cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
 
     transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
     trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
     return {tracking_points_var: tracking_points, input_image: trajectory_map}
+
 
 def delete_last_step(tracking_points, first_frame_path, drag_mode):
+    if drag_mode == "object":
         color = (255, 0, 0, 255)
+    elif drag_mode == "camera":
         color = (0, 0, 255, 255)
+    if tracking_points and tracking_points[-1]:
+        tracking_points[-1].pop()
+    transparent_background = Image.open(first_frame_path).convert("RGBA")
     w, h = transparent_background.size
     transparent_layer = np.zeros((h, w, 4))
+    for track in tracking_points:
+        if not track:
+            continue
         if len(track) > 1:
+            for i in range(len(track) - 1):
                 start_point = track[i]
+                end_point = track[i + 1]
                 vx = end_point[0] - start_point[0]
                 vy = end_point[1] - start_point[1]
                 arrow_length = np.sqrt(vx**2 + vy**2)
+                if i == len(track) - 2:
+                    cv2.arrowedLine(
+                        transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length
+                    )
                 else:
+                    cv2.line(
+                        transparent_layer,
+                        tuple(start_point),
+                        tuple(end_point),
+                        color,
+                        2,
+                    )
         else:
+            cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
 
     transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
     trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
     return {tracking_points_var: tracking_points, input_image: trajectory_map}
 
 
+def load_example(drag_mode, examples_type):
+    example_image_path = IMAGE_PATH[examples_type]
+    with open(POINTS[examples_type]) as f:
+        tracking_points = json.load(f)
+    tracking_points = np.round(tracking_points).astype(int).tolist()
+    trajectory_map, first_frame_path = preprocess_example_image(example_image_path, tracking_points, drag_mode)
+    return {input_image: trajectory_map, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
+
+
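`load_example` implies a simple on-disk format for the trajectory assets: a JSON nesting of tracks, each a list of `[x, y]` points, rounded to integer pixels before drawing. A hedged reading sketch (the exact schema is inferred from the `np.round(...).tolist()` handling above):

```python
import json

import numpy as np

with open("__asset__/trajs/object/turtle-1.json") as f:
    tracks = json.load(f)  # presumably [[[x0, y0], [x1, y1], ...], ...]
tracks = np.round(tracks).astype(int).tolist()
```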
| 619 | 
            +
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         | 
| 620 | 
            +
            ImageConductor_net = ImageConductor(
         | 
| 621 | 
            +
                device=device,
         | 
| 622 | 
            +
                unet_path=unet_path,
         | 
| 623 | 
            +
                image_controlnet_path=image_controlnet_path,
         | 
| 624 | 
            +
                flow_controlnet_path=flow_controlnet_path,
         | 
| 625 | 
            +
                height=256,
         | 
| 626 | 
            +
                width=384,
         | 
| 627 | 
            +
                model_length=16,
         | 
| 628 | 
            +
            )
         | 
| 629 | 
            +
             | 
| 630 | 
            +
            block = gr.Blocks(theme=gr.themes.Soft(radius_size=gr.themes.sizes.radius_none, text_size=gr.themes.sizes.text_md))
         | 
| 631 | 
             
            with block:
         | 
| 632 | 
             
                with gr.Row():
         | 
| 633 | 
             
                    with gr.Column():
         | 
|  | |
| 637 |  | 
| 638 | 
             
                with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
         | 
| 639 | 
             
                    with gr.Row(equal_height=True):
         | 
| 640 | 
            +
                        gr.Markdown(instructions)
         | 
| 641 | 
            +
             | 
| 642 | 
            +
                first_frame_path_var = gr.State()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 643 | 
             
                tracking_points_var = gr.State([])
         | 
| 644 |  | 
| 645 | 
             
                with gr.Row():
         | 
| 646 | 
             
                    with gr.Column(scale=1):
         | 
| 647 | 
            +
                        image_upload_button = gr.UploadButton(label="Upload Image", file_types=["image"])
         | 
| 648 | 
             
                        add_drag_button = gr.Button(value="Add Drag")
         | 
| 649 | 
             
                        reset_button = gr.Button(value="Reset")
         | 
| 650 | 
             
                        delete_last_drag_button = gr.Button(value="Delete last drag")
         | 
| 651 | 
             
                        delete_last_step_button = gr.Button(value="Delete last step")
         | 
|  | |
|  | |
| 652 |  | 
| 653 | 
             
                    with gr.Column(scale=7):
         | 
| 654 | 
             
                        with gr.Row():
         | 
| 655 | 
             
                            with gr.Column(scale=6):
         | 
| 656 | 
            +
                                input_image = gr.Image(
         | 
| 657 | 
            +
                                    label="Input Image",
         | 
| 658 | 
            +
                                    interactive=True,
         | 
| 659 | 
            +
                                    height=300,
         | 
| 660 | 
            +
                                    width=384,
         | 
| 661 | 
            +
                                )
         | 
| 662 | 
             
                            with gr.Column(scale=6):
         | 
| 663 | 
            +
                                output_image = gr.Image(
         | 
| 664 | 
            +
                                    label="Motion Path",
         | 
| 665 | 
            +
                                    interactive=False,
         | 
| 666 | 
            +
                                    height=256,
         | 
| 667 | 
            +
                                    width=384,
         | 
| 668 | 
            +
                                )
         | 
| 669 | 
             
                with gr.Row():
         | 
| 670 | 
             
                    with gr.Column(scale=1):
         | 
| 671 | 
            +
                        prompt = gr.Textbox(
         | 
| 672 | 
            +
                            value="a wonderful elf.",
         | 
| 673 | 
            +
                            label="Prompt (highly-recommended)",
         | 
| 674 | 
            +
                            interactive=True,
         | 
| 675 | 
            +
                            visible=True,
         | 
| 676 | 
            +
                        )
         | 
| 677 | 
             
                        negative_prompt = gr.Text(
         | 
| 678 | 
            +
                            label="Negative Prompt",
         | 
| 679 | 
            +
                            max_lines=5,
         | 
| 680 | 
            +
                            placeholder="Please input your negative prompt",
         | 
| 681 | 
            +
                            value="worst quality, low quality, letterboxed",
         | 
| 682 | 
            +
                            lines=1,
         | 
| 683 | 
            +
                        )
         | 
| 684 | 
            +
                        drag_mode = gr.Radio(["camera", "object"], label="Drag mode: ", value="object", scale=2)
         | 
| 685 | 
             
                        run_button = gr.Button(value="Run")
         | 

            with gr.Accordion("More input params", open=False, elem_id="accordion1"):
                with gr.Group():
                    seed = gr.Textbox(label="Seed: ", value=561793204)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

                with gr.Group():
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            # ... ten lines collapsed in the diff view: the slider's
                            # remaining arguments, presumably including the
                            # num_inference_steps slider referenced further down ...
                            step=1,
                            value=25,
                        )

                with gr.Group():
                    personalized = gr.Dropdown(label="Personalized", choices=["", "HelloObject", "TUSUN"], value="")
                    examples_type = gr.Textbox(label="Examples Type (Ignore) ", value="", visible=False)

        with gr.Column(scale=7):
            output_video = gr.Video(label="Output Video", width=384, height=256)
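
A detail worth flagging in the parameter accordion above: seed is a gr.Textbox, so the run handler receives its value as a string and has to coerce it to an integer itself, honoring the "Randomize seed" checkbox along the way. A minimal sketch of that coercion, assuming the app wants a fresh random seed whenever the box is checked (resolve_seed and MAX_SEED are illustrative names, not identifiers from app.py):

import random

MAX_SEED = 2**32 - 1  # illustrative bound, not taken from app.py

def resolve_seed(seed_text: str, randomize: bool) -> int:
    # Coerce the textbox value to an int, optionally replacing it entirely.
    if randomize:
        return random.randint(0, MAX_SEED)
    try:
        return int(seed_text)
    except ValueError:
        return 561793204  # fall back to the UI default shown above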

    with gr.Row():
        example = gr.Examples(
            label="Input Example",
            examples=image_examples,
            # ... one line collapsed in the diff view ...
            examples_per_page=10,
            cache_examples=False,
        )

    with gr.Row():
        gr.Markdown(citation)
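
With cache_examples=False, Gradio does not run anything at startup to pre-render example outputs. Instead, the demo appears to route example selection through the hidden examples_type textbox (wired up via examples_type.change further down), so picking an example only loads the input image and its saved trajectory; video generation still waits for the Run button.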

    image_upload_button.upload(
        preprocess_image,
        [image_upload_button, tracking_points_var],
        [input_image, first_frame_path_var, tracking_points_var, personalized],
    )

    add_drag_button.click(add_drag, tracking_points_var, tracking_points_var)

    delete_last_drag_button.click(
        delete_last_drag,
        [tracking_points_var, first_frame_path_var, drag_mode],
        [tracking_points_var, input_image],
    )
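
This wiring suggests the shape of the session state: tracking_points_var is a gr.State holding a list of trajectories, each trajectory a list of clicked (x, y) points. A sketch of what the two handlers above plausibly do with it; the real delete_last_drag also redraws the motion path onto the first frame, which this sketch omits:

def add_drag_sketch(tracking_points: list) -> list:
    # Open a fresh, empty trajectory for the next drag.
    tracking_points.append([])
    return tracking_points

def delete_last_drag_sketch(tracking_points: list) -> list:
    # Drop the most recent trajectory, if any.
    if tracking_points:
        tracking_points.pop()
    return tracking_points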

    delete_last_step_button.click(
        delete_last_step,
        [tracking_points_var, first_frame_path_var, drag_mode],
        [tracking_points_var, input_image],
    )

    reset_button.click(
        reset_states,
        [first_frame_path_var, tracking_points_var],
        [input_image, first_frame_path_var, tracking_points_var],
    )
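
reset_states is the easiest handler to reconstruct from its signature alone: it consumes both state values and must return a cleared canvas plus fresh state. Something like the following sketch (not the original body):

def reset_states_sketch(first_frame_path, tracking_points):
    # Blank the canvas, forget the saved first frame, start with no points.
    return None, None, []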

    input_image.select(
        add_tracking_points,
        [tracking_points_var, first_frame_path_var, drag_mode],
        [tracking_points_var, input_image],
    )
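
The .select event is how clicks become trajectory points: Gradio passes the handler an extra gr.SelectData payload whose .index field carries the clicked pixel coordinates. A stripped-down version of what add_tracking_points likely does with it; the real handler also draws the path onto the displayed frame and interprets points differently for camera versus object mode:

import gradio as gr

def add_tracking_points_sketch(tracking_points: list, evt: gr.SelectData) -> list:
    x, y = evt.index  # pixel coordinates of the click on the image
    if not tracking_points:
        tracking_points.append([])  # ensure there is a trajectory to extend
    tracking_points[-1].append((x, y))
    return tracking_points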

    run_button.click(
        ImageConductor_net.run,
        [
            first_frame_path_var,
            tracking_points_var,
            prompt,
            drag_mode,
            negative_prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
            personalized,
        ],
        [output_image, output_video],
    )
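
The input list fixes the signature of ImageConductor_net.run: ten positional arguments in exactly this order, returning the annotated motion path and the rendered video. A skeleton consistent with the wiring (the body and return values here are assumptions, not the original implementation):

class ImageConductorNetSketch:
    def run(self, first_frame_path, tracking_points, prompt, drag_mode,
            negative_prompt, seed, randomize_seed, guidance_scale,
            num_inference_steps, personalized):
        # ... diffusion sampling happens here in the real app ...
        motion_path_image = None   # image shown in output_image
        output_video_path = None   # file path shown in output_video
        return motion_path_image, output_video_path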

    examples_type.change(
        fn=load_example,
        inputs=[drag_mode, examples_type],
        outputs=[input_image, first_frame_path_var, tracking_points_var],
        api_name=False,
        queue=False,
    )
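
This handler completes the examples pattern: selecting an example row writes a key into the invisible examples_type textbox, whose .change event then installs the matching image and pre-recorded trajectory via load_example (queue=False keeps that cheap lookup out of the request queue). A minimal, self-contained illustration of the same trick, with made-up component names and a noise image standing in for real assets:

import gradio as gr
import numpy as np

with gr.Blocks() as demo:
    trigger = gr.Textbox(visible=False)  # gr.Examples writes the row value here
    picture = gr.Image()

    def on_trigger(key: str):
        # Stand-in for load_example: a real app would load the image and
        # saved trajectory that `key` names.
        rng = np.random.default_rng(abs(hash(key)) % 2**32)
        return rng.integers(0, 255, size=(64, 64, 3), dtype=np.uint8)

    trigger.change(on_trigger, inputs=trigger, outputs=picture, queue=False)
    gr.Examples(examples=[["turtle"], ["rose"]], inputs=[trigger])

# demo.launch() would start the toy app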

block.queue().launch()
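
A closing note on the launch line: queue() routes every event through Gradio's request queue, which long-running generation handlers need so that concurrent visitors wait in line instead of timing out, and launch() then starts the server. The queue can be bounded if desired, for example block.queue(max_size=20).launch(), though this app keeps the defaults.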
(The remaining entries in this commit are binary files, 188 bytes to 71.9 kB each, whose contents the diff viewer does not render.)

