Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	new base
Browse files- .gitignore +3 -1
 - app.py +14 -0
 - app_init.py +145 -0
 - build-run.sh +12 -0
 - config.py +58 -0
 - device.py +12 -0
 - frontend/.eslintignore +13 -0
 - frontend/.eslintrc.cjs +30 -0
 - frontend/.gitignore +10 -0
 - frontend/.npmrc +1 -0
 - frontend/.prettierignore +13 -0
 - frontend/.prettierrc +19 -0
 - frontend/README.md +38 -0
 - frontend/package-lock.json +0 -0
 - frontend/package.json +36 -0
 - frontend/postcss.config.js +6 -0
 - frontend/src/app.css +3 -0
 - frontend/src/app.d.ts +12 -0
 - frontend/src/app.html +12 -0
 - frontend/src/lib/index.ts +1 -0
 - frontend/src/lib/types.ts +0 -0
 - frontend/src/routes/+layout.svelte +5 -0
 - frontend/src/routes/+page.svelte +160 -0
 - frontend/src/routes/+page.ts +1 -0
 - frontend/static/favicon.png +0 -0
 - frontend/svelte.config.js +19 -0
 - frontend/tailwind.config.js +8 -0
 - frontend/tsconfig.json +17 -0
 - frontend/vite.config.ts +6 -0
 - pipelines/__init__.py +0 -0
 - pipelines/controlnet.py +90 -0
 - pipelines/txt2img.py +85 -0
 - pipelines/txt2imglora.py +93 -0
 - requirements.txt +2 -2
 - run.py +5 -0
 - user_queue.py +18 -0
 - util.py +16 -0
 
    	
        .gitignore
    CHANGED
    
    | 
         @@ -1,2 +1,4 @@ 
     | 
|
| 1 | 
         
             
            __pycache__/
         
     | 
| 2 | 
         
            -
            venv/
         
     | 
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
             
            __pycache__/
         
     | 
| 2 | 
         
            +
            venv/
         
     | 
| 3 | 
         
            +
            public/
         
     | 
| 4 | 
         
            +
            *.pem
         
     | 
    	
        app.py
    ADDED
    
    | 
         @@ -0,0 +1,14 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from fastapi import FastAPI
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            from config import args
         
     | 
| 4 | 
         
            +
            from device import device, torch_dtype
         
     | 
| 5 | 
         
            +
            from app_init import init_app
         
     | 
| 6 | 
         
            +
            from user_queue import user_queue_map
         
     | 
| 7 | 
         
            +
            from util import get_pipeline_class
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            app = FastAPI()
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            pipeline_class = get_pipeline_class(args.pipeline)
         
     | 
| 13 | 
         
            +
            pipeline = pipeline_class(args, device, torch_dtype)
         
     | 
| 14 | 
         
            +
            init_app(app, user_queue_map, args, pipeline)
         
     | 
    	
        app_init.py
    ADDED
    
    | 
         @@ -0,0 +1,145 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
         
     | 
| 2 | 
         
            +
            from fastapi.responses import StreamingResponse, JSONResponse
         
     | 
| 3 | 
         
            +
            from fastapi.middleware.cors import CORSMiddleware
         
     | 
| 4 | 
         
            +
            from fastapi.staticfiles import StaticFiles
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            import logging
         
     | 
| 7 | 
         
            +
            import traceback
         
     | 
| 8 | 
         
            +
            from config import Args
         
     | 
| 9 | 
         
            +
            from user_queue import UserQueueDict
         
     | 
| 10 | 
         
            +
            import uuid
         
     | 
| 11 | 
         
            +
            import asyncio
         
     | 
| 12 | 
         
            +
            import time
         
     | 
| 13 | 
         
            +
            from PIL import Image
         
     | 
| 14 | 
         
            +
            import io
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            def init_app(app: FastAPI, user_queue_map: UserQueueDict, args: Args, pipeline):
         
     | 
| 18 | 
         
            +
                app.add_middleware(
         
     | 
| 19 | 
         
            +
                    CORSMiddleware,
         
     | 
| 20 | 
         
            +
                    allow_origins=["*"],
         
     | 
| 21 | 
         
            +
                    allow_credentials=True,
         
     | 
| 22 | 
         
            +
                    allow_methods=["*"],
         
     | 
| 23 | 
         
            +
                    allow_headers=["*"],
         
     | 
| 24 | 
         
            +
                )
         
     | 
| 25 | 
         
            +
                print("Init app", app)
         
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
                @app.websocket("/ws")
         
     | 
| 28 | 
         
            +
                async def websocket_endpoint(websocket: WebSocket):
         
     | 
| 29 | 
         
            +
                    await websocket.accept()
         
     | 
| 30 | 
         
            +
                    if args.max_queue_size > 0 and len(user_queue_map) >= args.max_queue_size:
         
     | 
| 31 | 
         
            +
                        print("Server is full")
         
     | 
| 32 | 
         
            +
                        await websocket.send_json({"status": "error", "message": "Server is full"})
         
     | 
| 33 | 
         
            +
                        await websocket.close()
         
     | 
| 34 | 
         
            +
                        return
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
                    try:
         
     | 
| 37 | 
         
            +
                        uid = uuid.uuid4()
         
     | 
| 38 | 
         
            +
                        print(f"New user connected: {uid}")
         
     | 
| 39 | 
         
            +
                        await websocket.send_json(
         
     | 
| 40 | 
         
            +
                            {"status": "success", "message": "Connected", "userId": uid}
         
     | 
| 41 | 
         
            +
                        )
         
     | 
| 42 | 
         
            +
                        user_queue_map[uid] = {"queue": asyncio.Queue()}
         
     | 
| 43 | 
         
            +
                        await websocket.send_json(
         
     | 
| 44 | 
         
            +
                            {"status": "start", "message": "Start Streaming", "userId": uid}
         
     | 
| 45 | 
         
            +
                        )
         
     | 
| 46 | 
         
            +
                        await handle_websocket_data(websocket, uid)
         
     | 
| 47 | 
         
            +
                    except WebSocketDisconnect as e:
         
     | 
| 48 | 
         
            +
                        logging.error(f"WebSocket Error: {e}, {uid}")
         
     | 
| 49 | 
         
            +
                        traceback.print_exc()
         
     | 
| 50 | 
         
            +
                    finally:
         
     | 
| 51 | 
         
            +
                        print(f"User disconnected: {uid}")
         
     | 
| 52 | 
         
            +
                        queue_value = user_queue_map.pop(uid, None)
         
     | 
| 53 | 
         
            +
                        queue = queue_value.get("queue", None)
         
     | 
| 54 | 
         
            +
                        if queue:
         
     | 
| 55 | 
         
            +
                            while not queue.empty():
         
     | 
| 56 | 
         
            +
                                try:
         
     | 
| 57 | 
         
            +
                                    queue.get_nowait()
         
     | 
| 58 | 
         
            +
                                except asyncio.QueueEmpty:
         
     | 
| 59 | 
         
            +
                                    continue
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
                @app.get("/queue_size")
         
     | 
| 62 | 
         
            +
                async def get_queue_size():
         
     | 
| 63 | 
         
            +
                    queue_size = len(user_queue_map)
         
     | 
| 64 | 
         
            +
                    return JSONResponse({"queue_size": queue_size})
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                @app.get("/stream/{user_id}")
         
     | 
| 67 | 
         
            +
                async def stream(user_id: uuid.UUID):
         
     | 
| 68 | 
         
            +
                    uid = user_id
         
     | 
| 69 | 
         
            +
                    try:
         
     | 
| 70 | 
         
            +
                        user_queue = user_queue_map[uid]
         
     | 
| 71 | 
         
            +
                        queue = user_queue["queue"]
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
                        async def generate():
         
     | 
| 74 | 
         
            +
                            last_prompt: str = None
         
     | 
| 75 | 
         
            +
                            while True:
         
     | 
| 76 | 
         
            +
                                data = await queue.get()
         
     | 
| 77 | 
         
            +
                                input_image = data["image"]
         
     | 
| 78 | 
         
            +
                                params = data["params"]
         
     | 
| 79 | 
         
            +
                                if input_image is None:
         
     | 
| 80 | 
         
            +
                                    continue
         
     | 
| 81 | 
         
            +
             
     | 
| 82 | 
         
            +
                                image = pipeline.predict(
         
     | 
| 83 | 
         
            +
                                    input_image,
         
     | 
| 84 | 
         
            +
                                    params,
         
     | 
| 85 | 
         
            +
                                )
         
     | 
| 86 | 
         
            +
                                if image is None:
         
     | 
| 87 | 
         
            +
                                    continue
         
     | 
| 88 | 
         
            +
                                frame_data = io.BytesIO()
         
     | 
| 89 | 
         
            +
                                image.save(frame_data, format="JPEG")
         
     | 
| 90 | 
         
            +
                                frame_data = frame_data.getvalue()
         
     | 
| 91 | 
         
            +
                                if frame_data is not None and len(frame_data) > 0:
         
     | 
| 92 | 
         
            +
                                    yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
         
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
                                await asyncio.sleep(1.0 / 120.0)
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
                        return StreamingResponse(
         
     | 
| 97 | 
         
            +
                            generate(), media_type="multipart/x-mixed-replace;boundary=frame"
         
     | 
| 98 | 
         
            +
                        )
         
     | 
| 99 | 
         
            +
                    except Exception as e:
         
     | 
| 100 | 
         
            +
                        logging.error(f"Streaming Error: {e}, {user_queue_map}")
         
     | 
| 101 | 
         
            +
                        traceback.print_exc()
         
     | 
| 102 | 
         
            +
                        return HTTPException(status_code=404, detail="User not found")
         
     | 
| 103 | 
         
            +
             
     | 
| 104 | 
         
            +
                async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
         
     | 
| 105 | 
         
            +
                    uid = user_id
         
     | 
| 106 | 
         
            +
                    user_queue = user_queue_map[uid]
         
     | 
| 107 | 
         
            +
                    queue = user_queue["queue"]
         
     | 
| 108 | 
         
            +
                    if not queue:
         
     | 
| 109 | 
         
            +
                        return HTTPException(status_code=404, detail="User not found")
         
     | 
| 110 | 
         
            +
                    last_time = time.time()
         
     | 
| 111 | 
         
            +
                    try:
         
     | 
| 112 | 
         
            +
                        while True:
         
     | 
| 113 | 
         
            +
                            data = await websocket.receive_bytes()
         
     | 
| 114 | 
         
            +
                            params = await websocket.receive_json()
         
     | 
| 115 | 
         
            +
                            params = pipeline.InputParams(**params)
         
     | 
| 116 | 
         
            +
                            pil_image = Image.open(io.BytesIO(data))
         
     | 
| 117 | 
         
            +
             
     | 
| 118 | 
         
            +
                            while not queue.empty():
         
     | 
| 119 | 
         
            +
                                try:
         
     | 
| 120 | 
         
            +
                                    queue.get_nowait()
         
     | 
| 121 | 
         
            +
                                except asyncio.QueueEmpty:
         
     | 
| 122 | 
         
            +
                                    continue
         
     | 
| 123 | 
         
            +
                            await queue.put({"image": pil_image, "params": params})
         
     | 
| 124 | 
         
            +
                            if args.timeout > 0 and time.time() - last_time > args.timeout:
         
     | 
| 125 | 
         
            +
                                await websocket.send_json(
         
     | 
| 126 | 
         
            +
                                    {
         
     | 
| 127 | 
         
            +
                                        "status": "timeout",
         
     | 
| 128 | 
         
            +
                                        "message": "Your session has ended",
         
     | 
| 129 | 
         
            +
                                        "userId": uid,
         
     | 
| 130 | 
         
            +
                                    }
         
     | 
| 131 | 
         
            +
                                )
         
     | 
| 132 | 
         
            +
                                await websocket.close()
         
     | 
| 133 | 
         
            +
                                return
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
                    except Exception as e:
         
     | 
| 136 | 
         
            +
                        logging.error(f"Error: {e}")
         
     | 
| 137 | 
         
            +
                        traceback.print_exc()
         
     | 
| 138 | 
         
            +
             
     | 
| 139 | 
         
            +
                # route to setup frontend
         
     | 
| 140 | 
         
            +
                @app.get("/settings")
         
     | 
| 141 | 
         
            +
                async def settings():
         
     | 
| 142 | 
         
            +
                    params = pipeline.InputParams()
         
     | 
| 143 | 
         
            +
                    return JSONResponse({"settings": params.dict()})
         
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
                app.mount("/", StaticFiles(directory="public", html=True), name="public")
         
     | 
    	
        build-run.sh
    ADDED
    
    | 
         @@ -0,0 +1,12 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            #!/bin/bash
         
     | 
| 2 | 
         
            +
            cd frontend
         
     | 
| 3 | 
         
            +
            npm install
         
     | 
| 4 | 
         
            +
            npm run build
         
     | 
| 5 | 
         
            +
            if [ $? -eq 0 ]; then
         
     | 
| 6 | 
         
            +
                echo -e "\033[1;32m\nfrontend build success \033[0m"
         
     | 
| 7 | 
         
            +
            else
         
     | 
| 8 | 
         
            +
                echo -e "\033[1;31m\nfrontend build failed\n\033[0m" >&2  exit 1
         
     | 
| 9 | 
         
            +
            fi
         
     | 
| 10 | 
         
            +
            cd ../
         
     | 
| 11 | 
         
            +
            python run.py --reload
         
     | 
| 12 | 
         
            +
             
     | 
    	
        config.py
    ADDED
    
    | 
         @@ -0,0 +1,58 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from typing import NamedTuple
         
     | 
| 2 | 
         
            +
            import argparse
         
     | 
| 3 | 
         
            +
            import os
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            class Args(NamedTuple):
         
     | 
| 7 | 
         
            +
                host: str
         
     | 
| 8 | 
         
            +
                port: int
         
     | 
| 9 | 
         
            +
                reload: bool
         
     | 
| 10 | 
         
            +
                mode: str
         
     | 
| 11 | 
         
            +
                max_queue_size: int
         
     | 
| 12 | 
         
            +
                timeout: float
         
     | 
| 13 | 
         
            +
                safety_checker: bool
         
     | 
| 14 | 
         
            +
                torch_compile: bool
         
     | 
| 15 | 
         
            +
                use_taesd: bool
         
     | 
| 16 | 
         
            +
                pipeline: str
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
         
     | 
| 20 | 
         
            +
            TIMEOUT = float(os.environ.get("TIMEOUT", 0))
         
     | 
| 21 | 
         
            +
            SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None) == "True"
         
     | 
| 22 | 
         
            +
            TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None) == "True"
         
     | 
| 23 | 
         
            +
            USE_TAESD = os.environ.get("USE_TAESD", None) == "True"
         
     | 
| 24 | 
         
            +
            default_host = os.getenv("HOST", "0.0.0.0")
         
     | 
| 25 | 
         
            +
            default_port = int(os.getenv("PORT", "7860"))
         
     | 
| 26 | 
         
            +
            default_mode = os.getenv("MODE", "default")
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
            parser = argparse.ArgumentParser(description="Run the app")
         
     | 
| 29 | 
         
            +
            parser.add_argument("--host", type=str, default=default_host, help="Host address")
         
     | 
| 30 | 
         
            +
            parser.add_argument("--port", type=int, default=default_port, help="Port number")
         
     | 
| 31 | 
         
            +
            parser.add_argument("--reload", action="store_true", help="Reload code on change")
         
     | 
| 32 | 
         
            +
            parser.add_argument(
         
     | 
| 33 | 
         
            +
                "--mode", type=str, default=default_mode, help="App Inferece Mode: txt2img, img2img"
         
     | 
| 34 | 
         
            +
            )
         
     | 
| 35 | 
         
            +
            parser.add_argument(
         
     | 
| 36 | 
         
            +
                "--max_queue_size", type=int, default=MAX_QUEUE_SIZE, help="Max Queue Size"
         
     | 
| 37 | 
         
            +
            )
         
     | 
| 38 | 
         
            +
            parser.add_argument("--timeout", type=float, default=TIMEOUT, help="Timeout")
         
     | 
| 39 | 
         
            +
            parser.add_argument(
         
     | 
| 40 | 
         
            +
                "--safety_checker", type=bool, default=SAFETY_CHECKER, help="Safety Checker"
         
     | 
| 41 | 
         
            +
            )
         
     | 
| 42 | 
         
            +
            parser.add_argument(
         
     | 
| 43 | 
         
            +
                "--torch_compile", type=bool, default=TORCH_COMPILE, help="Torch Compile"
         
     | 
| 44 | 
         
            +
            )
         
     | 
| 45 | 
         
            +
            parser.add_argument(
         
     | 
| 46 | 
         
            +
                "--use_taesd",
         
     | 
| 47 | 
         
            +
                type=bool,
         
     | 
| 48 | 
         
            +
                default=USE_TAESD,
         
     | 
| 49 | 
         
            +
                help="Use Tiny Autoencoder",
         
     | 
| 50 | 
         
            +
            )
         
     | 
| 51 | 
         
            +
            parser.add_argument(
         
     | 
| 52 | 
         
            +
                "--pipeline",
         
     | 
| 53 | 
         
            +
                type=str,
         
     | 
| 54 | 
         
            +
                default="txt2img",
         
     | 
| 55 | 
         
            +
                help="Pipeline to use",
         
     | 
| 56 | 
         
            +
            )
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
            args = Args(**vars(parser.parse_args()))
         
     | 
    	
        device.py
    ADDED
    
    | 
         @@ -0,0 +1,12 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import torch
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            # check if MPS is available OSX only M1/M2/M3 chips
         
     | 
| 4 | 
         
            +
            mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
         
     | 
| 5 | 
         
            +
            xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
         
     | 
| 6 | 
         
            +
            device = torch.device(
         
     | 
| 7 | 
         
            +
                "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
         
     | 
| 8 | 
         
            +
            )
         
     | 
| 9 | 
         
            +
            torch_dtype = torch.float16
         
     | 
| 10 | 
         
            +
            if mps_available:
         
     | 
| 11 | 
         
            +
                device = torch.device("mps")
         
     | 
| 12 | 
         
            +
                torch_dtype = torch.float32
         
     | 
    	
        frontend/.eslintignore
    ADDED
    
    | 
         @@ -0,0 +1,13 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            .DS_Store
         
     | 
| 2 | 
         
            +
            node_modules
         
     | 
| 3 | 
         
            +
            /build
         
     | 
| 4 | 
         
            +
            /.svelte-kit
         
     | 
| 5 | 
         
            +
            /package
         
     | 
| 6 | 
         
            +
            .env
         
     | 
| 7 | 
         
            +
            .env.*
         
     | 
| 8 | 
         
            +
            !.env.example
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            # Ignore files for PNPM, NPM and YARN
         
     | 
| 11 | 
         
            +
            pnpm-lock.yaml
         
     | 
| 12 | 
         
            +
            package-lock.json
         
     | 
| 13 | 
         
            +
            yarn.lock
         
     | 
    	
        frontend/.eslintrc.cjs
    ADDED
    
    | 
         @@ -0,0 +1,30 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            module.exports = {
         
     | 
| 2 | 
         
            +
              root: true,
         
     | 
| 3 | 
         
            +
              extends: [
         
     | 
| 4 | 
         
            +
                'eslint:recommended',
         
     | 
| 5 | 
         
            +
                'plugin:@typescript-eslint/recommended',
         
     | 
| 6 | 
         
            +
                'plugin:svelte/recommended',
         
     | 
| 7 | 
         
            +
                'prettier'
         
     | 
| 8 | 
         
            +
              ],
         
     | 
| 9 | 
         
            +
              parser: '@typescript-eslint/parser',
         
     | 
| 10 | 
         
            +
              plugins: ['@typescript-eslint'],
         
     | 
| 11 | 
         
            +
              parserOptions: {
         
     | 
| 12 | 
         
            +
                sourceType: 'module',
         
     | 
| 13 | 
         
            +
                ecmaVersion: 2020,
         
     | 
| 14 | 
         
            +
                extraFileExtensions: ['.svelte']
         
     | 
| 15 | 
         
            +
              },
         
     | 
| 16 | 
         
            +
              env: {
         
     | 
| 17 | 
         
            +
                browser: true,
         
     | 
| 18 | 
         
            +
                es2017: true,
         
     | 
| 19 | 
         
            +
                node: true
         
     | 
| 20 | 
         
            +
              },
         
     | 
| 21 | 
         
            +
              overrides: [
         
     | 
| 22 | 
         
            +
                {
         
     | 
| 23 | 
         
            +
                  files: ['*.svelte'],
         
     | 
| 24 | 
         
            +
                  parser: 'svelte-eslint-parser',
         
     | 
| 25 | 
         
            +
                  parserOptions: {
         
     | 
| 26 | 
         
            +
                    parser: '@typescript-eslint/parser'
         
     | 
| 27 | 
         
            +
                  }
         
     | 
| 28 | 
         
            +
                }
         
     | 
| 29 | 
         
            +
              ]
         
     | 
| 30 | 
         
            +
            };
         
     | 
    	
        frontend/.gitignore
    ADDED
    
    | 
         @@ -0,0 +1,10 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            .DS_Store
         
     | 
| 2 | 
         
            +
            node_modules
         
     | 
| 3 | 
         
            +
            /build
         
     | 
| 4 | 
         
            +
            /.svelte-kit
         
     | 
| 5 | 
         
            +
            /package
         
     | 
| 6 | 
         
            +
            .env
         
     | 
| 7 | 
         
            +
            .env.*
         
     | 
| 8 | 
         
            +
            !.env.example
         
     | 
| 9 | 
         
            +
            vite.config.js.timestamp-*
         
     | 
| 10 | 
         
            +
            vite.config.ts.timestamp-*
         
     | 
    	
        frontend/.npmrc
    ADDED
    
    | 
         @@ -0,0 +1 @@ 
     | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            engine-strict=true
         
     | 
    	
        frontend/.prettierignore
    ADDED
    
    | 
         @@ -0,0 +1,13 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            .DS_Store
         
     | 
| 2 | 
         
            +
            node_modules
         
     | 
| 3 | 
         
            +
            /build
         
     | 
| 4 | 
         
            +
            /.svelte-kit
         
     | 
| 5 | 
         
            +
            /package
         
     | 
| 6 | 
         
            +
            .env
         
     | 
| 7 | 
         
            +
            .env.*
         
     | 
| 8 | 
         
            +
            !.env.example
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            # Ignore files for PNPM, NPM and YARN
         
     | 
| 11 | 
         
            +
            pnpm-lock.yaml
         
     | 
| 12 | 
         
            +
            package-lock.json
         
     | 
| 13 | 
         
            +
            yarn.lock
         
     | 
    	
        frontend/.prettierrc
    ADDED
    
    | 
         @@ -0,0 +1,19 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "useTabs": false,
         
     | 
| 3 | 
         
            +
              "singleQuote": true,
         
     | 
| 4 | 
         
            +
              "trailingComma": "none",
         
     | 
| 5 | 
         
            +
              "printWidth": 100,
         
     | 
| 6 | 
         
            +
              "plugins": [
         
     | 
| 7 | 
         
            +
                "prettier-plugin-svelte",
         
     | 
| 8 | 
         
            +
                "prettier-plugin-organize-imports",
         
     | 
| 9 | 
         
            +
                "prettier-plugin-tailwindcss"
         
     | 
| 10 | 
         
            +
              ],
         
     | 
| 11 | 
         
            +
              "overrides": [
         
     | 
| 12 | 
         
            +
                {
         
     | 
| 13 | 
         
            +
                  "files": "*.svelte",
         
     | 
| 14 | 
         
            +
                  "options": {
         
     | 
| 15 | 
         
            +
                    "parser": "svelte"
         
     | 
| 16 | 
         
            +
                  }
         
     | 
| 17 | 
         
            +
                }
         
     | 
| 18 | 
         
            +
              ]
         
     | 
| 19 | 
         
            +
            }
         
     | 
    	
        frontend/README.md
    ADDED
    
    | 
         @@ -0,0 +1,38 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # create-svelte
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            Everything you need to build a Svelte project, powered by [`create-svelte`](https://github.com/sveltejs/kit/tree/master/packages/create-svelte).
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            ## Creating a project
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            If you're seeing this, you've probably already done this step. Congrats!
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            ```bash
         
     | 
| 10 | 
         
            +
            # create a new project in the current directory
         
     | 
| 11 | 
         
            +
            npm create svelte@latest
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            # create a new project in my-app
         
     | 
| 14 | 
         
            +
            npm create svelte@latest my-app
         
     | 
| 15 | 
         
            +
            ```
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            ## Developing
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            Once you've created a project and installed dependencies with `npm install` (or `pnpm install` or `yarn`), start a development server:
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            ```bash
         
     | 
| 22 | 
         
            +
            npm run dev
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            # or start the server and open the app in a new browser tab
         
     | 
| 25 | 
         
            +
            npm run dev -- --open
         
     | 
| 26 | 
         
            +
            ```
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
            ## Building
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            To create a production version of your app:
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
            ```bash
         
     | 
| 33 | 
         
            +
            npm run build
         
     | 
| 34 | 
         
            +
            ```
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
            You can preview the production build with `npm run preview`.
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
            > To deploy your app, you may need to install an [adapter](https://kit.svelte.dev/docs/adapters) for your target environment.
         
     | 
    	
        frontend/package-lock.json
    ADDED
    
    | 
         The diff for this file is too large to render. 
		See raw diff 
     | 
| 
         | 
    	
        frontend/package.json
    ADDED
    
    | 
         @@ -0,0 +1,36 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "name": "frontend",
         
     | 
| 3 | 
         
            +
              "version": "0.0.1",
         
     | 
| 4 | 
         
            +
              "private": true,
         
     | 
| 5 | 
         
            +
              "scripts": {
         
     | 
| 6 | 
         
            +
                "dev": "vite dev",
         
     | 
| 7 | 
         
            +
                "build": "vite build",
         
     | 
| 8 | 
         
            +
                "preview": "vite preview",
         
     | 
| 9 | 
         
            +
                "check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
         
     | 
| 10 | 
         
            +
                "check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch",
         
     | 
| 11 | 
         
            +
                "lint": "prettier --check . && eslint .",
         
     | 
| 12 | 
         
            +
                "format": "prettier --write ."
         
     | 
| 13 | 
         
            +
              },
         
     | 
| 14 | 
         
            +
              "devDependencies": {
         
     | 
| 15 | 
         
            +
                "@sveltejs/adapter-auto": "^2.0.0",
         
     | 
| 16 | 
         
            +
                "@sveltejs/kit": "^1.20.4",
         
     | 
| 17 | 
         
            +
                "@typescript-eslint/eslint-plugin": "^6.0.0",
         
     | 
| 18 | 
         
            +
                "@typescript-eslint/parser": "^6.0.0",
         
     | 
| 19 | 
         
            +
                "autoprefixer": "^10.4.16",
         
     | 
| 20 | 
         
            +
                "eslint": "^8.28.0",
         
     | 
| 21 | 
         
            +
                "eslint-config-prettier": "^9.0.0",
         
     | 
| 22 | 
         
            +
                "eslint-plugin-svelte": "^2.30.0",
         
     | 
| 23 | 
         
            +
                "postcss": "^8.4.31",
         
     | 
| 24 | 
         
            +
                "prettier": "^3.1.0",
         
     | 
| 25 | 
         
            +
                "prettier-plugin-organize-imports": "^3.2.4",
         
     | 
| 26 | 
         
            +
                "prettier-plugin-svelte": "^3.1.0",
         
     | 
| 27 | 
         
            +
                "prettier-plugin-tailwindcss": "^0.5.7",
         
     | 
| 28 | 
         
            +
                "svelte": "^4.0.5",
         
     | 
| 29 | 
         
            +
                "svelte-check": "^3.4.3",
         
     | 
| 30 | 
         
            +
                "tailwindcss": "^3.3.5",
         
     | 
| 31 | 
         
            +
                "tslib": "^2.4.1",
         
     | 
| 32 | 
         
            +
                "typescript": "^5.0.0",
         
     | 
| 33 | 
         
            +
                "vite": "^4.4.2"
         
     | 
| 34 | 
         
            +
              },
         
     | 
| 35 | 
         
            +
              "type": "module"
         
     | 
| 36 | 
         
            +
            }
         
     | 
    	
        frontend/postcss.config.js
    ADDED
    
    | 
         @@ -0,0 +1,6 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            export default {
         
     | 
| 2 | 
         
            +
              plugins: {
         
     | 
| 3 | 
         
            +
                tailwindcss: {},
         
     | 
| 4 | 
         
            +
                autoprefixer: {}
         
     | 
| 5 | 
         
            +
              }
         
     | 
| 6 | 
         
            +
            };
         
     | 
    	
        frontend/src/app.css
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            @tailwind base;
         
     | 
| 2 | 
         
            +
            @tailwind components;
         
     | 
| 3 | 
         
            +
            @tailwind utilities;
         
     | 
    	
        frontend/src/app.d.ts
    ADDED
    
    | 
         @@ -0,0 +1,12 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            // See https://kit.svelte.dev/docs/types#app
         
     | 
| 2 | 
         
            +
            // for information about these interfaces
         
     | 
| 3 | 
         
            +
            declare global {
         
     | 
| 4 | 
         
            +
              namespace App {
         
     | 
| 5 | 
         
            +
                // interface Error {}
         
     | 
| 6 | 
         
            +
                // interface Locals {}
         
     | 
| 7 | 
         
            +
                // interface PageData {}
         
     | 
| 8 | 
         
            +
                // interface Platform {}
         
     | 
| 9 | 
         
            +
              }
         
     | 
| 10 | 
         
            +
            }
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            export {};
         
     | 
    	
        frontend/src/app.html
    ADDED
    
    | 
         @@ -0,0 +1,12 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            <!doctype html>
         
     | 
| 2 | 
         
            +
            <html lang="en">
         
     | 
| 3 | 
         
            +
              <head>
         
     | 
| 4 | 
         
            +
                <meta charset="utf-8" />
         
     | 
| 5 | 
         
            +
                <link rel="icon" href="%sveltekit.assets%/favicon.png" />
         
     | 
| 6 | 
         
            +
                <meta name="viewport" content="width=device-width, initial-scale=1" />
         
     | 
| 7 | 
         
            +
                %sveltekit.head%
         
     | 
| 8 | 
         
            +
              </head>
         
     | 
| 9 | 
         
            +
              <body data-sveltekit-preload-data="hover">
         
     | 
| 10 | 
         
            +
                <div style="display: contents">%sveltekit.body%</div>
         
     | 
| 11 | 
         
            +
              </body>
         
     | 
| 12 | 
         
            +
            </html>
         
     | 
    	
        frontend/src/lib/index.ts
    ADDED
    
    | 
         @@ -0,0 +1 @@ 
     | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            // place files you want to import through the `$lib` alias in this folder.
         
     | 
    	
        frontend/src/lib/types.ts
    ADDED
    
    | 
         
            File without changes
         
     | 
    	
        frontend/src/routes/+layout.svelte
    ADDED
    
    | 
         @@ -0,0 +1,5 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            <script>
         
     | 
| 2 | 
         
            +
              import '../app.css';
         
     | 
| 3 | 
         
            +
            </script>
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            <slot />
         
     | 
    	
        frontend/src/routes/+page.svelte
    ADDED
    
    | 
         @@ -0,0 +1,160 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            <script lang="ts">
         
     | 
| 2 | 
         
            +
              import { onMount } from 'svelte';
         
     | 
| 3 | 
         
            +
              import { PUBLIC_BASE_URL } from '$env/static/public';
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
              onMount(() => {
         
     | 
| 6 | 
         
            +
                getSettings();
         
     | 
| 7 | 
         
            +
              });
         
     | 
| 8 | 
         
            +
              async function getSettings() {
         
     | 
| 9 | 
         
            +
                const settings = await fetch(`${PUBLIC_BASE_URL}/settings`).then((r) => r.json());
         
     | 
| 10 | 
         
            +
                console.log(settings);
         
     | 
| 11 | 
         
            +
              }
         
     | 
| 12 | 
         
            +
            </script>
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            <div class="fixed right-2 top-2 max-w-xs rounded-lg p-4 text-center text-sm font-bold" id="error" />
         
     | 
| 15 | 
         
            +
            <main class="container mx-auto flex max-w-4xl flex-col gap-4 px-4 py-4">
         
     | 
| 16 | 
         
            +
              <article class="mx-auto max-w-xl text-center">
         
     | 
| 17 | 
         
            +
                <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
         
     | 
| 18 | 
         
            +
                <h2 class="mb-4 text-2xl font-bold">Image to Image</h2>
         
     | 
| 19 | 
         
            +
                <p class="text-sm">
         
     | 
| 20 | 
         
            +
                  This demo showcases
         
     | 
| 21 | 
         
            +
                  <a
         
     | 
| 22 | 
         
            +
                    href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7"
         
     | 
| 23 | 
         
            +
                    target="_blank"
         
     | 
| 24 | 
         
            +
                    class="text-blue-500 underline hover:no-underline">LCM</a
         
     | 
| 25 | 
         
            +
                  >
         
     | 
| 26 | 
         
            +
                  Image to Image pipeline using
         
     | 
| 27 | 
         
            +
                  <a
         
     | 
| 28 | 
         
            +
                    href="https://github.com/huggingface/diffusers/tree/main/examples/community#latent-consistency-pipeline"
         
     | 
| 29 | 
         
            +
                    target="_blank"
         
     | 
| 30 | 
         
            +
                    class="text-blue-500 underline hover:no-underline">Diffusers</a
         
     | 
| 31 | 
         
            +
                  > with a MJPEG stream server.
         
     | 
| 32 | 
         
            +
                </p>
         
     | 
| 33 | 
         
            +
                <p class="text-sm">
         
     | 
| 34 | 
         
            +
                  There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU,
         
     | 
| 35 | 
         
            +
                  affecting real-time performance. Maximum queue size is 4.
         
     | 
| 36 | 
         
            +
                  <a
         
     | 
| 37 | 
         
            +
                    href="https://huggingface.co/spaces/radames/Real-Time-Latent-Consistency-Model?duplicate=true"
         
     | 
| 38 | 
         
            +
                    target="_blank"
         
     | 
| 39 | 
         
            +
                    class="text-blue-500 underline hover:no-underline">Duplicate</a
         
     | 
| 40 | 
         
            +
                  > and run it on your own GPU.
         
     | 
| 41 | 
         
            +
                </p>
         
     | 
| 42 | 
         
            +
              </article>
         
     | 
| 43 | 
         
            +
              <div>
         
     | 
| 44 | 
         
            +
                <h2 class="font-medium">Prompt</h2>
         
     | 
| 45 | 
         
            +
                <p class="text-sm text-gray-500">
         
     | 
| 46 | 
         
            +
                  Change the prompt to generate different images, accepts <a
         
     | 
| 47 | 
         
            +
                    href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
         
     | 
| 48 | 
         
            +
                    target="_blank"
         
     | 
| 49 | 
         
            +
                    class="text-blue-500 underline hover:no-underline">Compel</a
         
     | 
| 50 | 
         
            +
                  > syntax.
         
     | 
| 51 | 
         
            +
                </p>
         
     | 
| 52 | 
         
            +
                <div class="text-normal flex items-center rounded-md border border-gray-700 px-1 py-1">
         
     | 
| 53 | 
         
            +
                  <textarea
         
     | 
| 54 | 
         
            +
                    type="text"
         
     | 
| 55 | 
         
            +
                    id="prompt"
         
     | 
| 56 | 
         
            +
                    class="mx-1 w-full px-3 py-2 font-light outline-none dark:text-black"
         
     | 
| 57 | 
         
            +
                    title="Prompt, this is an example, feel free to modify"
         
     | 
| 58 | 
         
            +
                    placeholder="Add your prompt here..."
         
     | 
| 59 | 
         
            +
                    >Portrait of The Terminator with , glare pose, detailed, intricate, full of colour,
         
     | 
| 60 | 
         
            +
                    cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details,
         
     | 
| 61 | 
         
            +
                    unreal engine 5, cinematic, masterpiece</textarea
         
     | 
| 62 | 
         
            +
                  >
         
     | 
| 63 | 
         
            +
                </div>
         
     | 
| 64 | 
         
            +
              </div>
         
     | 
| 65 | 
         
            +
              <div class="">
         
     | 
| 66 | 
         
            +
                <details>
         
     | 
| 67 | 
         
            +
                  <summary class="cursor-pointer font-medium">Advanced Options</summary>
         
     | 
| 68 | 
         
            +
                  <div class="grid max-w-md grid-cols-3 items-center gap-3 py-3">
         
     | 
| 69 | 
         
            +
                    <label class="text-sm font-medium" for="guidance-scale">Guidance Scale </label>
         
     | 
| 70 | 
         
            +
                    <input
         
     | 
| 71 | 
         
            +
                      type="range"
         
     | 
| 72 | 
         
            +
                      id="guidance-scale"
         
     | 
| 73 | 
         
            +
                      name="guidance-scale"
         
     | 
| 74 | 
         
            +
                      min="1"
         
     | 
| 75 | 
         
            +
                      max="30"
         
     | 
| 76 | 
         
            +
                      step="0.001"
         
     | 
| 77 | 
         
            +
                      value="8.0"
         
     | 
| 78 | 
         
            +
                      oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
         
     | 
| 79 | 
         
            +
                    />
         
     | 
| 80 | 
         
            +
                    <output
         
     | 
| 81 | 
         
            +
                      class="w-[50px] rounded-md border border-gray-700 px-1 py-1 text-center text-xs font-light"
         
     | 
| 82 | 
         
            +
                    >
         
     | 
| 83 | 
         
            +
                      8.0</output
         
     | 
| 84 | 
         
            +
                    >
         
     | 
| 85 | 
         
            +
                    <label class="text-sm font-medium" for="strength">Strength</label>
         
     | 
| 86 | 
         
            +
                    <input
         
     | 
| 87 | 
         
            +
                      type="range"
         
     | 
| 88 | 
         
            +
                      id="strength"
         
     | 
| 89 | 
         
            +
                      name="strength"
         
     | 
| 90 | 
         
            +
                      min="0.20"
         
     | 
| 91 | 
         
            +
                      max="1"
         
     | 
| 92 | 
         
            +
                      step="0.001"
         
     | 
| 93 | 
         
            +
                      value="0.50"
         
     | 
| 94 | 
         
            +
                      oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
         
     | 
| 95 | 
         
            +
                    />
         
     | 
| 96 | 
         
            +
                    <output
         
     | 
| 97 | 
         
            +
                      class="w-[50px] rounded-md border border-gray-700 px-1 py-1 text-center text-xs font-light"
         
     | 
| 98 | 
         
            +
                    >
         
     | 
| 99 | 
         
            +
                      0.5</output
         
     | 
| 100 | 
         
            +
                    >
         
     | 
| 101 | 
         
            +
                    <label class="text-sm font-medium" for="seed">Seed</label>
         
     | 
| 102 | 
         
            +
                    <input
         
     | 
| 103 | 
         
            +
                      type="number"
         
     | 
| 104 | 
         
            +
                      id="seed"
         
     | 
| 105 | 
         
            +
                      name="seed"
         
     | 
| 106 | 
         
            +
                      value="299792458"
         
     | 
| 107 | 
         
            +
                      class="rounded-md border border-gray-700 p-2 text-right font-light dark:text-black"
         
     | 
| 108 | 
         
            +
                    />
         
     | 
| 109 | 
         
            +
                    <button
         
     | 
| 110 | 
         
            +
                      onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER)"
         
     | 
| 111 | 
         
            +
                      class="button"
         
     | 
| 112 | 
         
            +
                    >
         
     | 
| 113 | 
         
            +
                      Rand
         
     | 
| 114 | 
         
            +
                    </button>
         
     | 
| 115 | 
         
            +
                  </div>
         
     | 
| 116 | 
         
            +
                </details>
         
     | 
| 117 | 
         
            +
              </div>
         
     | 
| 118 | 
         
            +
              <div class="flex gap-3">
         
     | 
| 119 | 
         
            +
                <button id="start" class="button"> Start </button>
         
     | 
| 120 | 
         
            +
                <button id="stop" class="button"> Stop </button>
         
     | 
| 121 | 
         
            +
                <button id="snap" disabled class="button ml-auto"> Snapshot </button>
         
     | 
| 122 | 
         
            +
              </div>
         
     | 
| 123 | 
         
            +
              <div class="relative overflow-hidden rounded-lg border border-slate-300">
         
     | 
| 124 | 
         
            +
                <img
         
     | 
| 125 | 
         
            +
                  id="player"
         
     | 
| 126 | 
         
            +
                  class="aspect-square w-full rounded-lg"
         
     | 
| 127 | 
         
            +
                  src=""
         
     | 
| 128 | 
         
            +
                />
         
     | 
| 129 | 
         
            +
                <div class="absolute left-0 top-0 aspect-square w-1/4">
         
     | 
| 130 | 
         
            +
                  <video
         
     | 
| 131 | 
         
            +
                    id="webcam"
         
     | 
| 132 | 
         
            +
                    class="relative z-10 aspect-square w-full object-cover"
         
     | 
| 133 | 
         
            +
                    playsinline
         
     | 
| 134 | 
         
            +
                    autoplay
         
     | 
| 135 | 
         
            +
                    muted
         
     | 
| 136 | 
         
            +
                    loop
         
     | 
| 137 | 
         
            +
                  />
         
     | 
| 138 | 
         
            +
                  <svg
         
     | 
| 139 | 
         
            +
                    xmlns="http://www.w3.org/2000/svg"
         
     | 
| 140 | 
         
            +
                    viewBox="0 0 448 448"
         
     | 
| 141 | 
         
            +
                    width="100"
         
     | 
| 142 | 
         
            +
                    class="absolute top-0 z-0 w-full p-4 opacity-20"
         
     | 
| 143 | 
         
            +
                  >
         
     | 
| 144 | 
         
            +
                    <path
         
     | 
| 145 | 
         
            +
                      fill="currentColor"
         
     | 
| 146 | 
         
            +
                      d="M224 256a128 128 0 1 0 0-256 128 128 0 1 0 0 256zm-45.7 48A178.3 178.3 0 0 0 0 482.3 29.7 29.7 0 0 0 29.7 512h388.6a29.7 29.7 0 0 0 29.7-29.7c0-98.5-79.8-178.3-178.3-178.3h-91.4z"
         
     | 
| 147 | 
         
            +
                    />
         
     | 
| 148 | 
         
            +
                  </svg>
         
     | 
| 149 | 
         
            +
                </div>
         
     | 
| 150 | 
         
            +
              </div>
         
     | 
| 151 | 
         
            +
            </main>
         
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
            <style lang="postcss">
         
     | 
| 154 | 
         
            +
              :global(html) {
         
     | 
| 155 | 
         
            +
                @apply text-black dark:bg-gray-900 dark:text-white;
         
     | 
| 156 | 
         
            +
              }
         
     | 
| 157 | 
         
            +
              .button {
         
     | 
| 158 | 
         
            +
                @apply rounded bg-gray-700 p-2 font-normal text-white hover:bg-gray-800 disabled:cursor-not-allowed disabled:bg-gray-300 dark:disabled:bg-gray-700 dark:disabled:text-black;
         
     | 
| 159 | 
         
            +
              }
         
     | 
| 160 | 
         
            +
            </style>
         
     | 
    	
        frontend/src/routes/+page.ts
    ADDED
    
    | 
         @@ -0,0 +1 @@ 
     | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            export const prerender = true
         
     | 
    	
        frontend/static/favicon.png
    ADDED
    
    | 
											 | 
									
								
    	
        frontend/svelte.config.js
    ADDED
    
    | 
         @@ -0,0 +1,19 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import adapter from '@sveltejs/adapter-static';
         
     | 
| 2 | 
         
            +
            import { vitePreprocess } from '@sveltejs/kit/vite';
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            /** @type {import('@sveltejs/kit').Config} */
         
     | 
| 5 | 
         
            +
            const config = {
         
     | 
| 6 | 
         
            +
              preprocess: vitePreprocess(),
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
              kit: {
         
     | 
| 9 | 
         
            +
                adapter: adapter({
         
     | 
| 10 | 
         
            +
                  pages: '../public',
         
     | 
| 11 | 
         
            +
                  assets: '../public',
         
     | 
| 12 | 
         
            +
                  fallback: undefined,
         
     | 
| 13 | 
         
            +
                  precompress: false,
         
     | 
| 14 | 
         
            +
                  strict: true
         
     | 
| 15 | 
         
            +
                })
         
     | 
| 16 | 
         
            +
              }
         
     | 
| 17 | 
         
            +
            };
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            export default config;
         
     | 
    	
        frontend/tailwind.config.js
    ADDED
    
    | 
         @@ -0,0 +1,8 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            /** @type {import('tailwindcss').Config} */
         
     | 
| 2 | 
         
            +
            export default {
         
     | 
| 3 | 
         
            +
              content: ['./src/**/*.{html,js,svelte,ts}'],
         
     | 
| 4 | 
         
            +
              theme: {
         
     | 
| 5 | 
         
            +
                extend: {}
         
     | 
| 6 | 
         
            +
              },
         
     | 
| 7 | 
         
            +
              plugins: []
         
     | 
| 8 | 
         
            +
            };
         
     | 
    	
        frontend/tsconfig.json
    ADDED
    
    | 
         @@ -0,0 +1,17 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "extends": "./.svelte-kit/tsconfig.json",
         
     | 
| 3 | 
         
            +
              "compilerOptions": {
         
     | 
| 4 | 
         
            +
                "allowJs": true,
         
     | 
| 5 | 
         
            +
                "checkJs": true,
         
     | 
| 6 | 
         
            +
                "esModuleInterop": true,
         
     | 
| 7 | 
         
            +
                "forceConsistentCasingInFileNames": true,
         
     | 
| 8 | 
         
            +
                "resolveJsonModule": true,
         
     | 
| 9 | 
         
            +
                "skipLibCheck": true,
         
     | 
| 10 | 
         
            +
                "sourceMap": true,
         
     | 
| 11 | 
         
            +
                "strict": true
         
     | 
| 12 | 
         
            +
              }
         
     | 
| 13 | 
         
            +
              // Path aliases are handled by https://kit.svelte.dev/docs/configuration#alias
         
     | 
| 14 | 
         
            +
              //
         
     | 
| 15 | 
         
            +
              // If you want to overwrite includes/excludes, make sure to copy over the relevant includes/excludes
         
     | 
| 16 | 
         
            +
              // from the referenced tsconfig.json - TypeScript does not merge them in
         
     | 
| 17 | 
         
            +
            }
         
     | 
    	
        frontend/vite.config.ts
    ADDED
    
    | 
         @@ -0,0 +1,6 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import { sveltekit } from '@sveltejs/kit/vite';
         
     | 
| 2 | 
         
            +
            import { defineConfig } from 'vite';
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            export default defineConfig({
         
     | 
| 5 | 
         
            +
              plugins: [sveltekit()]
         
     | 
| 6 | 
         
            +
            });
         
     | 
    	
        pipelines/__init__.py
    ADDED
    
    | 
         
            File without changes
         
     | 
    	
        pipelines/controlnet.py
    ADDED
    
    | 
         @@ -0,0 +1,90 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from diffusers import DiffusionPipeline, AutoencoderTiny
         
     | 
| 2 | 
         
            +
            from latent_consistency_controlnet import LatentConsistencyModelPipeline_controlnet
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            from compel import Compel
         
     | 
| 5 | 
         
            +
            import torch
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            try:
         
     | 
| 8 | 
         
            +
                import intel_extension_for_pytorch as ipex  # type: ignore
         
     | 
| 9 | 
         
            +
            except:
         
     | 
| 10 | 
         
            +
                pass
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            import psutil
         
     | 
| 13 | 
         
            +
            from config import Args
         
     | 
| 14 | 
         
            +
            from pydantic import BaseModel
         
     | 
| 15 | 
         
            +
            from PIL import Image
         
     | 
| 16 | 
         
            +
            from typing import Callable
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            base_model = "SimianLuo/LCM_Dreamshaper_v7"
         
     | 
| 19 | 
         
            +
            WIDTH = 512
         
     | 
| 20 | 
         
            +
            HEIGHT = 512
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
            class Pipeline:
         
     | 
| 24 | 
         
            +
                class InputParams(BaseModel):
         
     | 
| 25 | 
         
            +
                    seed: int = 2159232
         
     | 
| 26 | 
         
            +
                    prompt: str
         
     | 
| 27 | 
         
            +
                    guidance_scale: float = 8.0
         
     | 
| 28 | 
         
            +
                    strength: float = 0.5
         
     | 
| 29 | 
         
            +
                    steps: int = 4
         
     | 
| 30 | 
         
            +
                    lcm_steps: int = 50
         
     | 
| 31 | 
         
            +
                    width: int = WIDTH
         
     | 
| 32 | 
         
            +
                    height: int = HEIGHT
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
                @staticmethod
         
     | 
| 35 | 
         
            +
                def create_pipeline(
         
     | 
| 36 | 
         
            +
                    args: Args, device: torch.device, torch_dtype: torch.dtype
         
     | 
| 37 | 
         
            +
                ) -> Callable[["Pipeline.InputParams"], Image.Image]:
         
     | 
| 38 | 
         
            +
                    if args.safety_checker:
         
     | 
| 39 | 
         
            +
                        pipe = DiffusionPipeline.from_pretrained(base_model)
         
     | 
| 40 | 
         
            +
                    else:
         
     | 
| 41 | 
         
            +
                        pipe = DiffusionPipeline.from_pretrained(base_model, safety_checker=None)
         
     | 
| 42 | 
         
            +
                    if args.use_taesd:
         
     | 
| 43 | 
         
            +
                        pipe.vae = AutoencoderTiny.from_pretrained(
         
     | 
| 44 | 
         
            +
                            "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
         
     | 
| 45 | 
         
            +
                        )
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
                    pipe.set_progress_bar_config(disable=True)
         
     | 
| 48 | 
         
            +
                    pipe.to(device=device, dtype=torch_dtype)
         
     | 
| 49 | 
         
            +
                    pipe.unet.to(memory_format=torch.channels_last)
         
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
                    # check if computer has less than 64GB of RAM using sys or os
         
     | 
| 52 | 
         
            +
                    if psutil.virtual_memory().total < 64 * 1024**3:
         
     | 
| 53 | 
         
            +
                        pipe.enable_attention_slicing()
         
     | 
| 54 | 
         
            +
             
     | 
| 55 | 
         
            +
                    if args.torch_compile:
         
     | 
| 56 | 
         
            +
                        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
         
     | 
| 57 | 
         
            +
                        pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
                        pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
                    compel_proc = Compel(
         
     | 
| 62 | 
         
            +
                        tokenizer=pipe.tokenizer,
         
     | 
| 63 | 
         
            +
                        text_encoder=pipe.text_encoder,
         
     | 
| 64 | 
         
            +
                        truncate_long_prompts=False,
         
     | 
| 65 | 
         
            +
                    )
         
     | 
| 66 | 
         
            +
             
     | 
| 67 | 
         
            +
                    def predict(params: "Pipeline.InputParams") -> Image.Image:
         
     | 
| 68 | 
         
            +
                        generator = torch.manual_seed(params.seed)
         
     | 
| 69 | 
         
            +
                        prompt_embeds = compel_proc(params.prompt)
         
     | 
| 70 | 
         
            +
                        # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
         
     | 
| 71 | 
         
            +
                        results = pipe(
         
     | 
| 72 | 
         
            +
                            prompt_embeds=prompt_embeds,
         
     | 
| 73 | 
         
            +
                            generator=generator,
         
     | 
| 74 | 
         
            +
                            num_inference_steps=params.steps,
         
     | 
| 75 | 
         
            +
                            guidance_scale=params.guidance_scale,
         
     | 
| 76 | 
         
            +
                            width=params.width,
         
     | 
| 77 | 
         
            +
                            height=params.height,
         
     | 
| 78 | 
         
            +
                            original_inference_steps=params.lcm_steps,
         
     | 
| 79 | 
         
            +
                            output_type="pil",
         
     | 
| 80 | 
         
            +
                        )
         
     | 
| 81 | 
         
            +
                        nsfw_content_detected = (
         
     | 
| 82 | 
         
            +
                            results.nsfw_content_detected[0]
         
     | 
| 83 | 
         
            +
                            if "nsfw_content_detected" in results
         
     | 
| 84 | 
         
            +
                            else False
         
     | 
| 85 | 
         
            +
                        )
         
     | 
| 86 | 
         
            +
                        if nsfw_content_detected:
         
     | 
| 87 | 
         
            +
                            return None
         
     | 
| 88 | 
         
            +
                        return results.images[0]
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
                    return predict
         
     | 
    	
        pipelines/txt2img.py
    ADDED
    
    | 
         @@ -0,0 +1,85 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from diffusers import DiffusionPipeline, AutoencoderTiny
         
     | 
| 2 | 
         
            +
            from compel import Compel
         
     | 
| 3 | 
         
            +
            import torch
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            try:
         
     | 
| 6 | 
         
            +
                import intel_extension_for_pytorch as ipex  # type: ignore
         
     | 
| 7 | 
         
            +
            except:
         
     | 
| 8 | 
         
            +
                pass
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            import psutil
         
     | 
| 11 | 
         
            +
            from config import Args
         
     | 
| 12 | 
         
            +
            from pydantic import BaseModel
         
     | 
| 13 | 
         
            +
            from PIL import Image
         
     | 
| 14 | 
         
            +
            from typing import Callable
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            base_model = "SimianLuo/LCM_Dreamshaper_v7"
         
     | 
| 17 | 
         
            +
            taesd_model = "madebyollin/taesd"
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            class Pipeline:
         
     | 
| 21 | 
         
            +
                class InputParams(BaseModel):
         
     | 
| 22 | 
         
            +
                    seed: int = 2159232
         
     | 
| 23 | 
         
            +
                    prompt: str = ""
         
     | 
| 24 | 
         
            +
                    guidance_scale: float = 8.0
         
     | 
| 25 | 
         
            +
                    strength: float = 0.5
         
     | 
| 26 | 
         
            +
                    steps: int = 4
         
     | 
| 27 | 
         
            +
                    width: int = 512
         
     | 
| 28 | 
         
            +
                    height: int = 512
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
                def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
         
     | 
| 31 | 
         
            +
                    if args.safety_checker:
         
     | 
| 32 | 
         
            +
                        self.pipe = DiffusionPipeline.from_pretrained(base_model)
         
     | 
| 33 | 
         
            +
                    else:
         
     | 
| 34 | 
         
            +
                        self.pipe = DiffusionPipeline.from_pretrained(
         
     | 
| 35 | 
         
            +
                            base_model, safety_checker=None
         
     | 
| 36 | 
         
            +
                        )
         
     | 
| 37 | 
         
            +
                    if args.use_taesd:
         
     | 
| 38 | 
         
            +
                        self.pipe.vae = AutoencoderTiny.from_pretrained(
         
     | 
| 39 | 
         
            +
                            taesd_model, torch_dtype=torch_dtype, use_safetensors=True
         
     | 
| 40 | 
         
            +
                        )
         
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
                    self.pipe.set_progress_bar_config(disable=True)
         
     | 
| 43 | 
         
            +
                    self.pipe.to(device=device, dtype=torch_dtype)
         
     | 
| 44 | 
         
            +
                    self.pipe.unet.to(memory_format=torch.channels_last)
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
                    # check if computer has less than 64GB of RAM using sys or os
         
     | 
| 47 | 
         
            +
                    if psutil.virtual_memory().total < 64 * 1024**3:
         
     | 
| 48 | 
         
            +
                        self.pipe.enable_attention_slicing()
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                    if args.torch_compile:
         
     | 
| 51 | 
         
            +
                        self.pipe.unet = torch.compile(
         
     | 
| 52 | 
         
            +
                            self.pipe.unet, mode="reduce-overhead", fullgraph=True
         
     | 
| 53 | 
         
            +
                        )
         
     | 
| 54 | 
         
            +
                        self.pipe.vae = torch.compile(
         
     | 
| 55 | 
         
            +
                            self.pipe.vae, mode="reduce-overhead", fullgraph=True
         
     | 
| 56 | 
         
            +
                        )
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
                        self.pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
                    self.compel_proc = Compel(
         
     | 
| 61 | 
         
            +
                        tokenizer=self.pipe.tokenizer,
         
     | 
| 62 | 
         
            +
                        text_encoder=self.pipe.text_encoder,
         
     | 
| 63 | 
         
            +
                        truncate_long_prompts=False,
         
     | 
| 64 | 
         
            +
                    )
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                def predict(self, params: "Pipeline.InputParams") -> Image.Image:
         
     | 
| 67 | 
         
            +
                    generator = torch.manual_seed(params.seed)
         
     | 
| 68 | 
         
            +
                    prompt_embeds = self.compel_proc(params.prompt)
         
     | 
| 69 | 
         
            +
                    results = self.pipe(
         
     | 
| 70 | 
         
            +
                        prompt_embeds=prompt_embeds,
         
     | 
| 71 | 
         
            +
                        generator=generator,
         
     | 
| 72 | 
         
            +
                        num_inference_steps=params.steps,
         
     | 
| 73 | 
         
            +
                        guidance_scale=params.guidance_scale,
         
     | 
| 74 | 
         
            +
                        width=params.width,
         
     | 
| 75 | 
         
            +
                        height=params.height,
         
     | 
| 76 | 
         
            +
                        output_type="pil",
         
     | 
| 77 | 
         
            +
                    )
         
     | 
| 78 | 
         
            +
                    nsfw_content_detected = (
         
     | 
| 79 | 
         
            +
                        results.nsfw_content_detected[0]
         
     | 
| 80 | 
         
            +
                        if "nsfw_content_detected" in results
         
     | 
| 81 | 
         
            +
                        else False
         
     | 
| 82 | 
         
            +
                    )
         
     | 
| 83 | 
         
            +
                    if nsfw_content_detected:
         
     | 
| 84 | 
         
            +
                        return None
         
     | 
| 85 | 
         
            +
                    return results.images[0]
         
     | 
    	
        pipelines/txt2imglora.py
    ADDED
    
    | 
         @@ -0,0 +1,93 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from diffusers import DiffusionPipeline, AutoencoderTiny
         
     | 
| 2 | 
         
            +
            from compel import Compel
         
     | 
| 3 | 
         
            +
            import torch
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            try:
         
     | 
| 6 | 
         
            +
                import intel_extension_for_pytorch as ipex  # type: ignore
         
     | 
| 7 | 
         
            +
            except:
         
     | 
| 8 | 
         
            +
                pass
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            import psutil
         
     | 
| 11 | 
         
            +
            from config import Args
         
     | 
| 12 | 
         
            +
            from pydantic import BaseModel
         
     | 
| 13 | 
         
            +
            from PIL import Image
         
     | 
| 14 | 
         
            +
            from typing import Callable
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            base_model = "SimianLuo/LCM_Dreamshaper_v7"
         
     | 
| 17 | 
         
            +
            WIDTH = 512
         
     | 
| 18 | 
         
            +
            HEIGHT = 512
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            model_id = "wavymulder/Analog-Diffusion"
         
     | 
| 21 | 
         
            +
            lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            class Pipeline:
         
     | 
| 25 | 
         
            +
                class InputParams(BaseModel):
         
     | 
| 26 | 
         
            +
                    seed: int = 2159232
         
     | 
| 27 | 
         
            +
                    prompt: str
         
     | 
| 28 | 
         
            +
                    guidance_scale: float = 8.0
         
     | 
| 29 | 
         
            +
                    strength: float = 0.5
         
     | 
| 30 | 
         
            +
                    steps: int = 4
         
     | 
| 31 | 
         
            +
                    lcm_steps: int = 50
         
     | 
| 32 | 
         
            +
                    width: int = WIDTH
         
     | 
| 33 | 
         
            +
                    height: int = HEIGHT
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
                @staticmethod
         
     | 
| 36 | 
         
            +
                def create_pipeline(
         
     | 
| 37 | 
         
            +
                    args: Args, device: torch.device, torch_dtype: torch.dtype
         
     | 
| 38 | 
         
            +
                ) -> Callable[["Pipeline.InputParams"], Image.Image]:
         
     | 
| 39 | 
         
            +
                    if args.safety_checker:
         
     | 
| 40 | 
         
            +
                        pipe = DiffusionPipeline.from_pretrained(base_model)
         
     | 
| 41 | 
         
            +
                    else:
         
     | 
| 42 | 
         
            +
                        pipe = DiffusionPipeline.from_pretrained(base_model, safety_checker=None)
         
     | 
| 43 | 
         
            +
                    if args.use_taesd:
         
     | 
| 44 | 
         
            +
                        pipe.vae = AutoencoderTiny.from_pretrained(
         
     | 
| 45 | 
         
            +
                            "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
         
     | 
| 46 | 
         
            +
                        )
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
                    pipe.set_progress_bar_config(disable=True)
         
     | 
| 49 | 
         
            +
                    pipe.to(device=device, dtype=torch_dtype)
         
     | 
| 50 | 
         
            +
                    pipe.unet.to(memory_format=torch.channels_last)
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
                    # Load LCM LoRA
         
     | 
| 53 | 
         
            +
                    pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm")
         
     | 
| 54 | 
         
            +
                    # check if computer has less than 64GB of RAM using sys or os
         
     | 
| 55 | 
         
            +
                    if psutil.virtual_memory().total < 64 * 1024**3:
         
     | 
| 56 | 
         
            +
                        pipe.enable_attention_slicing()
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
                    if args.torch_compile:
         
     | 
| 59 | 
         
            +
                        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
         
     | 
| 60 | 
         
            +
                        pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
                        pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
                    compel_proc = Compel(
         
     | 
| 65 | 
         
            +
                        tokenizer=pipe.tokenizer,
         
     | 
| 66 | 
         
            +
                        text_encoder=pipe.text_encoder,
         
     | 
| 67 | 
         
            +
                        truncate_long_prompts=False,
         
     | 
| 68 | 
         
            +
                    )
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
                    def predict(params: "Pipeline.InputParams") -> Image.Image:
         
     | 
| 71 | 
         
            +
                        generator = torch.manual_seed(params.seed)
         
     | 
| 72 | 
         
            +
                        prompt_embeds = compel_proc(params.prompt)
         
     | 
| 73 | 
         
            +
                        # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
         
     | 
| 74 | 
         
            +
                        results = pipe(
         
     | 
| 75 | 
         
            +
                            prompt_embeds=prompt_embeds,
         
     | 
| 76 | 
         
            +
                            generator=generator,
         
     | 
| 77 | 
         
            +
                            num_inference_steps=params.steps,
         
     | 
| 78 | 
         
            +
                            guidance_scale=params.guidance_scale,
         
     | 
| 79 | 
         
            +
                            width=params.width,
         
     | 
| 80 | 
         
            +
                            height=params.height,
         
     | 
| 81 | 
         
            +
                            original_inference_steps=params.lcm_steps,
         
     | 
| 82 | 
         
            +
                            output_type="pil",
         
     | 
| 83 | 
         
            +
                        )
         
     | 
| 84 | 
         
            +
                        nsfw_content_detected = (
         
     | 
| 85 | 
         
            +
                            results.nsfw_content_detected[0]
         
     | 
| 86 | 
         
            +
                            if "nsfw_content_detected" in results
         
     | 
| 87 | 
         
            +
                            else False
         
     | 
| 88 | 
         
            +
                        )
         
     | 
| 89 | 
         
            +
                        if nsfw_content_detected:
         
     | 
| 90 | 
         
            +
                            return None
         
     | 
| 91 | 
         
            +
                        return results.images[0]
         
     | 
| 92 | 
         
            +
             
     | 
| 93 | 
         
            +
                    return predict
         
     | 
    	
        requirements.txt
    CHANGED
    
    | 
         @@ -3,8 +3,8 @@ transformers==4.34.1 
     | 
|
| 3 | 
         
             
            gradio==3.50.2
         
     | 
| 4 | 
         
             
            --extra-index-url https://download.pytorch.org/whl/cu121;
         
     | 
| 5 | 
         
             
            torch==2.1.0
         
     | 
| 6 | 
         
            -
            fastapi==0.104. 
     | 
| 7 | 
         
            -
            uvicorn==0. 
     | 
| 8 | 
         
             
            Pillow==10.1.0
         
     | 
| 9 | 
         
             
            accelerate==0.24.0
         
     | 
| 10 | 
         
             
            compel==2.0.2
         
     | 
| 
         | 
|
| 3 | 
         
             
            gradio==3.50.2
         
     | 
| 4 | 
         
             
            --extra-index-url https://download.pytorch.org/whl/cu121;
         
     | 
| 5 | 
         
             
            torch==2.1.0
         
     | 
| 6 | 
         
            +
            fastapi==0.104.1
         
     | 
| 7 | 
         
            +
            uvicorn==0.24.0.post1
         
     | 
| 8 | 
         
             
            Pillow==10.1.0
         
     | 
| 9 | 
         
             
            accelerate==0.24.0
         
     | 
| 10 | 
         
             
            compel==2.0.2
         
     | 
    	
        run.py
    ADDED
    
    | 
         @@ -0,0 +1,5 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 2 | 
         
            +
                import uvicorn
         
     | 
| 3 | 
         
            +
                from config import args
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
                uvicorn.run("app:app", host=args.host, port=args.port, reload=args.reload)
         
     | 
    	
        user_queue.py
    ADDED
    
    | 
         @@ -0,0 +1,18 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from typing import Dict, Union
         
     | 
| 2 | 
         
            +
            from uuid import UUID
         
     | 
| 3 | 
         
            +
            from asyncio import Queue
         
     | 
| 4 | 
         
            +
            from PIL import Image
         
     | 
| 5 | 
         
            +
            from typing import Tuple, Union
         
     | 
| 6 | 
         
            +
            from uuid import UUID
         
     | 
| 7 | 
         
            +
            from asyncio import Queue
         
     | 
| 8 | 
         
            +
            from PIL import Image
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            UserId = UUID
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            InputParams = dict
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            QueueContent = Dict[str, Union[Image.Image, InputParams]]
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            UserQueueDict = Dict[UserId, Queue[QueueContent]]
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            user_queue_map: UserQueueDict = {}
         
     | 
    	
        util.py
    ADDED
    
    | 
         @@ -0,0 +1,16 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from importlib import import_module
         
     | 
| 2 | 
         
            +
            from types import ModuleType
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            def get_pipeline_class(pipeline_name: str) -> ModuleType:
         
     | 
| 6 | 
         
            +
                try:
         
     | 
| 7 | 
         
            +
                    module = import_module(f"pipelines.{pipeline_name}")
         
     | 
| 8 | 
         
            +
                except ModuleNotFoundError:
         
     | 
| 9 | 
         
            +
                    raise ValueError(f"Pipeline {pipeline_name} module not found")
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
                pipeline_class = getattr(module, "Pipeline", None)
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
                if pipeline_class is None:
         
     | 
| 14 | 
         
            +
                    raise ValueError(f"'Pipeline' class not found in module '{module_name}'.")
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
                return pipeline_class
         
     |