from collections import OrderedDict
import os
import sqlite3
import asyncio
import concurrent.futures
from extensions_built_in.sd_trainer.SDTrainer import SDTrainer
from typing import Literal, Optional

AITK_Status = Literal["running", "stopped", "error", "completed"]

class UITrainer(SDTrainer):
    """SDTrainer variant that reports step, speed, and status updates to a shared
    SQLite database so a UI can monitor the job, and that polls the same database
    for stop requests."""

    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super(UITrainer, self).__init__(process_id, job, config, **kwargs)
        self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
        if not os.path.exists(self.sqlite_db_path):
            raise Exception(
                f"SQLite database not found at {self.sqlite_db_path}")
        print(f"Using SQLite database at {self.sqlite_db_path}")
        self.job_id = os.environ.get("AITK_JOB_ID", None)
        self.job_id = self.job_id.strip() if self.job_id is not None else None
        print(f"Job ID: \"{self.job_id}\"")
        if self.job_id is None:
            raise Exception("AITK_JOB_ID not set")
        self.is_stopping = False
        # Create a thread pool for database operations
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        # Track all async tasks
        self._async_tasks = []
        # Initialize the status
        self._run_async_operation(self._update_status("running", "Starting"))
    def _run_async_operation(self, coro):
        """Helper method to run an async coroutine and track the task."""
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No event loop exists, create a new one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        # Create a task and track it
        if loop.is_running():
            task = asyncio.run_coroutine_threadsafe(coro, loop)
            self._async_tasks.append(asyncio.wrap_future(task))
        else:
            task = loop.create_task(coro)
            self._async_tasks.append(task)
            loop.run_until_complete(task)
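    # Note on the dispatch above: when an event loop is already running, the coroutine is
    # scheduled onto it with run_coroutine_threadsafe and the resulting future is wrapped
    # and tracked; otherwise the coroutine becomes a task that is driven to completion
    # synchronously via run_until_complete. Either way the task lands in self._async_tasks
    # so that wait_for_all_async() can drain it at shutdown.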
    async def _execute_db_operation(self, operation_func):
        """Execute a database operation in a separate thread to avoid blocking."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.thread_pool, operation_func)

    def _db_connect(self):
        """Create a new connection for each operation to avoid locking."""
        conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0)
        conn.isolation_level = None  # Enable autocommit mode
        return conn
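    # For reference, the statements in this class assume a Job table roughly like the
    # sketch below (inferred from the queries used in this file; the real schema may
    # have additional columns and constraints):
    #
    #   CREATE TABLE Job (
    #       id TEXT PRIMARY KEY,
    #       status TEXT,
    #       info TEXT,
    #       step INTEGER,
    #       stop INTEGER,
    #       speed_string TEXT
    #   );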
    def should_stop(self):
        """Return True if the UI has set the stop flag on this job's row."""
        def _check_stop():
            with self._db_connect() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT stop FROM Job WHERE id = ?", (self.job_id,))
                stop = cursor.fetchone()
                return False if stop is None else stop[0] == 1
        return _check_stop()

    def maybe_stop(self):
        """Mark the job as stopped and raise to unwind the training loop if requested."""
        if self.should_stop():
            self._run_async_operation(
                self._update_status("stopped", "Job stopped"))
            self.is_stopping = True
            raise Exception("Job stopped")
    async def _update_key(self, key, value):
        if not self.accelerator.is_main_process:
            return

        def _do_update():
            with self._db_connect() as conn:
                cursor = conn.cursor()
                cursor.execute("BEGIN IMMEDIATE")
                try:
                    # Convert the value to string if it's not already
                    if isinstance(value, str):
                        value_to_insert = value
                    else:
                        value_to_insert = str(value)
                    # Parameterize the value; the column name cannot be parameterized,
                    # so it is interpolated directly and callers must pass trusted keys only
                    update_query = f"UPDATE Job SET {key} = ? WHERE id = ?"
                    cursor.execute(
                        update_query, (value_to_insert, self.job_id))
                finally:
                    cursor.execute("COMMIT")

        await self._execute_db_operation(_do_update)
    def update_step(self):
        """Non-blocking update of the step count."""
        if self.accelerator.is_main_process:
            self._run_async_operation(self._update_key("step", self.step_num))

    def update_db_key(self, key, value):
        """Non-blocking update of a key in the database."""
        if self.accelerator.is_main_process:
            self._run_async_operation(self._update_key(key, value))
    async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
        if not self.accelerator.is_main_process:
            return

        def _do_update():
            with self._db_connect() as conn:
                cursor = conn.cursor()
                cursor.execute("BEGIN IMMEDIATE")
                try:
                    if info is not None:
                        cursor.execute(
                            "UPDATE Job SET status = ?, info = ? WHERE id = ?",
                            (status, info, self.job_id)
                        )
                    else:
                        cursor.execute(
                            "UPDATE Job SET status = ? WHERE id = ?",
                            (status, self.job_id)
                        )
                finally:
                    cursor.execute("COMMIT")

        await self._execute_db_operation(_do_update)
    def update_status(self, status: AITK_Status, info: Optional[str] = None):
        """Non-blocking update of status."""
        if self.accelerator.is_main_process:
            self._run_async_operation(self._update_status(status, info))
    async def wait_for_all_async(self):
        """Wait for all tracked async operations to complete."""
        if not self._async_tasks:
            return
        try:
            await asyncio.gather(*self._async_tasks)
        except Exception:
            # Failures in background status updates should not crash shutdown
            pass
        finally:
            # Clear the task list after completion
            self._async_tasks.clear()
    def on_error(self, e: Exception):
        super(UITrainer, self).on_error(e)
        if self.accelerator.is_main_process and not self.is_stopping:
            self.update_status("error", str(e))
            self.update_db_key("step", self.last_save_step)
        asyncio.run(self.wait_for_all_async())
        self.thread_pool.shutdown(wait=True)
    def handle_timing_print_hook(self, timing_dict):
        if "train_loop" not in timing_dict:
            print("train_loop not found in timing_dict", timing_dict)
            return
        seconds_per_iter = timing_dict["train_loop"]
        # determine iter/sec or sec/iter
        if seconds_per_iter < 1:
            iters_per_sec = 1 / seconds_per_iter
            self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec")
        else:
            self.update_db_key(
                "speed_string", f"{seconds_per_iter:.2f} sec/iter")
    def done_hook(self):
        super(UITrainer, self).done_hook()
        self.update_status("completed", "Training completed")
        # Wait for all async operations to finish before shutting down
        asyncio.run(self.wait_for_all_async())
        self.thread_pool.shutdown(wait=True)

    def end_step_hook(self):
        super(UITrainer, self).end_step_hook()
        self.update_step()
        self.maybe_stop()

    def hook_before_model_load(self):
        super().hook_before_model_load()
        self.maybe_stop()
        self.update_status("running", "Loading model")

    def before_dataset_load(self):
        super().before_dataset_load()
        self.maybe_stop()
        self.update_status("running", "Loading dataset")

    def hook_before_train_loop(self):
        super().hook_before_train_loop()
        self.maybe_stop()
        self.update_step()
        self.update_status("running", "Training")
        self.timer.add_after_print_hook(self.handle_timing_print_hook)

    def status_update_hook_func(self, string):
        self.update_status("running", string)

    def hook_after_sd_init_before_load(self):
        super().hook_after_sd_init_before_load()
        self.maybe_stop()
        self.sd.add_status_update_hook(self.status_update_hook_func)

    def sample_step_hook(self, img_num, total_imgs):
        super().sample_step_hook(img_num, total_imgs)
        self.maybe_stop()
        self.update_status(
            "running", f"Generating images - {img_num + 1}/{total_imgs}")

    def sample(self, step=None, is_first=False):
        self.maybe_stop()
        total_imgs = len(self.sample_config.prompts)
        self.update_status("running", f"Generating images - 0/{total_imgs}")
        super().sample(step, is_first)
        self.maybe_stop()
        self.update_status("running", "Training")

    def save(self, step=None):
        self.maybe_stop()
        self.update_status("running", "Saving model")
        super().save(step)
        self.maybe_stop()
        self.update_status("running", "Training")