cbensimon (HF Staff) committed
Commit f0b5714 · 1 Parent(s): c79e236
.gitignore ADDED
@@ -0,0 +1 @@
+ *.pyc
spaces/__init__.py ADDED
@@ -0,0 +1,30 @@
+ """
+ """
+
+ import sys
+
+
+ if sys.version_info.minor < 8: # pragma: no cover
+     raise RuntimeError("Importing PySpaces requires Python 3.8+")
+
+
+ # Prevent gradio from importing spaces
+ if (gr := sys.modules.get('gradio')) is not None: # pragma: no cover
+     try:
+         gr.Blocks
+     except AttributeError:
+         raise ImportError
+
+
+ from .zero.decorator import GPU
+ from .gradio import gradio_auto_wrap
+ from .gradio import disable_gradio_auto_wrap
+ from .gradio import enable_gradio_auto_wrap
+
+
+ __all__ = [
+     'GPU',
+     'gradio_auto_wrap',
+     'disable_gradio_auto_wrap',
+     'enable_gradio_auto_wrap',
+ ]
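The package entry point only re-exports the ZeroGPU decorator and the Gradio auto-wrap helpers, so Space code normally interacts with it through `spaces.GPU`. A minimal usage sketch (the model and function below are illustrative; off a ZeroGPU Space the decorator simply returns the function unchanged):

```
import spaces
import torch

model = torch.nn.Linear(8, 2)
model.to('cuda')  # on ZeroGPU, the actual move is deferred until a GPU is attached

@spaces.GPU  # or @spaces.GPU(duration=30) to adjust the estimated duration
def predict(x):
    # CUDA is only really available while this function runs
    return model(torch.tensor(x, device='cuda')).tolist()
```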
spaces/config.py ADDED
@@ -0,0 +1,53 @@
+ """
+ """
+ from __future__ import annotations
+
+ import os
+ from pathlib import Path
+
+
+ ZEROGPU_OFFLOAD_DIR_DEFAULT = str(Path.home() / '.zerogpu' / 'tensors')
+
+
+ def boolean(value: str | None) -> bool:
+     return value is not None and value.lower() in ("1", "t", "true")
+
+
+ class Settings:
+     def __init__(self):
+         self.zero_gpu = boolean(
+             os.getenv('SPACES_ZERO_GPU'))
+         self.zero_device_api_url = (
+             os.getenv('SPACES_ZERO_DEVICE_API_URL'))
+         self.gradio_auto_wrap = boolean(
+             os.getenv('SPACES_GRADIO_AUTO_WRAP'))
+         self.zero_patch_torch_device = boolean(
+             os.getenv('ZERO_GPU_PATCH_TORCH_DEVICE'))
+         self.zero_gpu_v2 = boolean(
+             os.getenv('ZEROGPU_V2'))
+         self.zerogpu_offload_dir = (
+             os.getenv('ZEROGPU_OFFLOAD_DIR', ZEROGPU_OFFLOAD_DIR_DEFAULT))
+         self.zerogpu_proc_self_cgroup_path = (
+             os.getenv('ZEROGPU_PROC_SELF_CGROUP_PATH', '/proc/self/cgroup'))
+         self.zerogpu_cuda_device_name = (
+             os.getenv('ZEROGPU_CUDA_DEVICE_NAME', "NVIDIA H200 MIG 3g.71gb"))
+         self.zerogpu_cuda_total_memory = int(
+             os.getenv('ZEROGPU_CUDA_TOTAL_MEMORY', 74625056768))
+         self.zerogpu_cuda_reserved_memory = int(
+             os.getenv('ZEROGPU_CUDA_RESERVED_MEMORY', 0))
+         self.zerogpu_cuda_capability_major = int(
+             os.getenv('ZEROGPU_CUDA_CAPABILITY_MAJOR', 9))
+         self.zerogpu_cuda_capability_minor = int(
+             os.getenv('ZEROGPU_CUDA_CAPABILITY_MINOR', 0))
+         self.zerogpu_cuda_multi_processor_count = int(
+             os.getenv('ZEROGPU_CUDA_MULTI_PROCESSOR_COUNT', 60))
+
+
+ Config = Settings()
+
+
+ if Config.zero_gpu:
+     assert Config.zero_device_api_url is not None, (
+         'SPACES_ZERO_DEVICE_API_URL env must be set '
+         'on ZeroGPU Spaces (identified by SPACES_ZERO_GPU=true)'
+     )
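`Settings` reads every value from environment variables exactly once, when `spaces` is first imported, so any override must be in place beforehand. A small sketch (the variable values and path are illustrative):

```
import os

# Must run before `import spaces`: Config = Settings() is evaluated at import time
os.environ['SPACES_GRADIO_AUTO_WRAP'] = 'true'
os.environ['ZEROGPU_OFFLOAD_DIR'] = '/data/zerogpu-tensors'  # illustrative path

import spaces
from spaces.config import Config

assert Config.gradio_auto_wrap is True
assert Config.zerogpu_offload_dir == '/data/zerogpu-tensors'
```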
spaces/gradio.py ADDED
@@ -0,0 +1,55 @@
+ """
+ """
+ from __future__ import annotations
+
+ from typing import Callable
+ from typing import Generator
+ from typing import TypeVar
+ from typing import overload
+ from typing_extensions import ParamSpec
+
+ from .config import Config
+ from .zero.decorator import GPU
+
+
+ Param = ParamSpec('Param')
+ Res = TypeVar('Res')
+
+
+ gradio_auto_wrap_enabled = Config.gradio_auto_wrap
+
+
+ def disable_gradio_auto_wrap():
+     global gradio_auto_wrap_enabled
+     gradio_auto_wrap_enabled = False
+
+ def enable_gradio_auto_wrap():
+     global gradio_auto_wrap_enabled
+     gradio_auto_wrap_enabled = True
+
+
+ @overload
+ def gradio_auto_wrap(
+     task:
+         Callable[Param, Res],
+ ) -> Callable[Param, Res]:
+     ...
+ @overload
+ def gradio_auto_wrap(
+     task:
+         None,
+ ) -> None:
+     ...
+ def gradio_auto_wrap(
+     task:
+         Callable[Param, Res]
+         | None,
+ ) -> (Callable[Param, Res]
+       | None):
+     """
+     """
+     if not gradio_auto_wrap_enabled:
+         return task
+     if not callable(task):
+         return task
+     return GPU(task) # type: ignore
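`gradio_auto_wrap` is a guarded pass-through: it wraps a callable with `GPU` only when the feature is enabled, and leaves `None` or non-callables untouched. A sketch of the toggle behaviour (off a ZeroGPU Space, `GPU` itself returns the task unchanged):

```
from spaces.gradio import (
    gradio_auto_wrap,
    enable_gradio_auto_wrap,
    disable_gradio_auto_wrap,
)

def fn(x):
    return x

enable_gradio_auto_wrap()
wrapped = gradio_auto_wrap(fn)         # equivalent to spaces.GPU(fn)
assert gradio_auto_wrap(None) is None  # non-callables pass through untouched

disable_gradio_auto_wrap()
assert gradio_auto_wrap(fn) is fn      # disabled: the task is returned as-is
```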
spaces/utils.py ADDED
@@ -0,0 +1,91 @@
+ """
+ """
+ from __future__ import annotations
+
+ import base64
+ import ctypes
+ import json
+ import sys
+ from functools import lru_cache as cache
+ from functools import partial
+ from typing import Any
+
+ import multiprocessing
+ from multiprocessing.queues import SimpleQueue as _SimpleQueue
+ from pathlib import Path
+ from pickle import PicklingError
+ from typing import Callable
+ from typing import TypeVar
+
+ from .config import Config
+
+
+ GRADIO_VERSION_ERROR_MESSAGE = "Make sure Gradio version is at least 3.46"
+
+
+ T = TypeVar('T')
+
+
+ @cache
+ def self_cgroup_device_path() -> str:
+     cgroup_content = Path(Config.zerogpu_proc_self_cgroup_path).read_text()
+     for line in cgroup_content.strip().split('\n'):
+         contents = line.split(':devices:')
+         if len(contents) != 2:
+             continue # pragma: no cover
+         return contents[1]
+     raise Exception # pragma: no cover
+
+
+ if sys.version_info.minor < 9: # pragma: no cover
+     _SimpleQueue.__class_getitem__ = classmethod(lambda cls, _: cls) # type: ignore
+
+ class SimpleQueue(_SimpleQueue[T]):
+     def __init__(self, *args):
+         super().__init__(*args, ctx=multiprocessing.get_context('fork'))
+     def put(self, obj: T):
+         try:
+             super().put(obj)
+         except PicklingError:
+             raise # pragma: no cover
+         # https://bugs.python.org/issue29187
+         except Exception as e:
+             message = str(e)
+             if not "pickle" in message:
+                 raise # pragma: no cover
+             raise PicklingError(message)
+     def close(self): # Python 3.8 static typing trick
+         super().close() # type: ignore
+     def wlock_release(self):
+         if (lock := getattr(self, '_wlock', None)) is None:
+             return # pragma: no cover
+         try:
+             lock.release()
+         except ValueError:
+             pass
+
+
+ def drop_params(fn: Callable[[], T]) -> Callable[..., T]:
+     def drop(*args):
+         return fn()
+     return drop
+
+
+ def gradio_request_var():
+     try:
+         from gradio.context import LocalContext
+     except ImportError: # pragma: no cover
+         raise RuntimeError(GRADIO_VERSION_ERROR_MESSAGE)
+     return LocalContext.request
+
+
+ def malloc_trim():
+     ctypes.CDLL("libc.so.6").malloc_trim(0)
+
+
+ debug = partial(print, 'SPACES_ZERO_GPU_DEBUG')
+
+
+ def jwt_payload(token: str) -> dict[str, Any]:
+     _, payload, _ = token.split('.')
+     return json.loads(base64.urlsafe_b64decode(f'{payload}=='))
spaces/zero/__init__.py ADDED
@@ -0,0 +1,21 @@
+ """
+ """
+
+ from pathlib import Path
+
+ from ..config import Config
+
+
+ if Config.zero_gpu:
+
+     from . import gradio
+     from . import torch
+
+     if torch.is_in_bad_fork():
+         raise RuntimeError(
+             "CUDA has been initialized before importing the `spaces` package"
+         )
+
+     torch.patch()
+     gradio.one_launch(torch.pack)
+     Path(Config.zerogpu_offload_dir).mkdir(parents=True, exist_ok=True)
spaces/zero/api.py ADDED
@@ -0,0 +1,159 @@
+ """
+ Synced with huggingface/pyspaces:spaces/zero/api.py
+ """
+ from __future__ import annotations
+
+ from datetime import timedelta
+ from typing import Any
+ from typing import Generator
+ from typing import Literal
+ from typing import NamedTuple
+ from typing import Optional
+ from typing import overload
+
+ import httpx
+ from pydantic import BaseModel
+ from typing_extensions import assert_never
+
+
+ AllowToken = str
+ NvidiaIndex = int # TODO: Migrate to GpuIndex (less confusing for MIG)
+ NvidiaUUID = str
+ CGroupPath = str
+ TaskId = int
+
+ AuthLevel = Literal['regular', 'pro']
+ QueuingReason = Literal['node', 'concurrency']
+
+
+ AUTHENTICATED_HEADER = 'X-Authenticated'
+ QUEUING_REASON_HEADER = 'X-Queuing-Reason'
+
+
+ class ScheduleResponse(BaseModel):
+     idle: bool
+     nvidiaIndex: int
+     nvidiaUUID: str
+     allowToken: str
+
+
+ class ScheduleMetadata(BaseModel):
+     auth: Optional[AuthLevel] = None
+     queuing_reason: Optional[QueuingReason] = None
+
+
+ class QuotaInfos(BaseModel):
+     left: int
+     wait: timedelta
+
+
+ class QueueEvent(BaseModel):
+     event: Literal['ping', 'failed', 'succeeded']
+     data: Optional[ScheduleResponse] = None
+
+
+ def sse_parse(text: str):
+     event, *data = text.strip().splitlines()
+     assert event.startswith('event:')
+     event = event[6:].strip()
+     if event in ('ping', 'failed'):
+         return QueueEvent(event=event)
+     assert event == 'succeeded'
+     (data,) = data
+     assert data.startswith('data:')
+     data = data[5:].strip()
+     return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data))
+
+
+ def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
+     for text in res.iter_text():
+         if len(text) == 0:
+             break # pragma: no cover
+         try:
+             yield sse_parse(text)
+         except GeneratorExit:
+             res.close()
+             break
+
+
+ class APIClient:
+
+     def __init__(self, client: httpx.Client):
+         self.client = client
+
+     def startup_report(self) -> httpx.codes:
+         res = self.client.post('/startup-report')
+         return httpx.codes(res.status_code)
+
+     def schedule(
+         self,
+         cgroup_path: str,
+         task_id: int = 0,
+         token: str | None = None,
+         token_version: int = 1,
+         duration_seconds: int | None = None,
+         enable_queue: bool = True,
+     ):
+         params: dict[str, str | int | bool] = {
+             'cgroupPath': cgroup_path,
+             'taskId': task_id,
+             'enableQueue': enable_queue,
+             'tokenVersion': token_version,
+         }
+         if duration_seconds is not None:
+             params['durationSeconds'] = duration_seconds
+         if token is not None:
+             params['token'] = token
+         res = self.client.send(
+             request=self.client.build_request(
+                 method='POST',
+                 url='/schedule',
+                 params=params,
+             ),
+             stream=True,
+         )
+         status = httpx.codes(res.status_code)
+         auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER)
+         queuing_reason: QueuingReason | None = res.headers.get(QUEUING_REASON_HEADER)
+         metadata = ScheduleMetadata(auth=auth, queuing_reason=queuing_reason)
+         if (status is not httpx.codes.OK and
+             status is not httpx.codes.TOO_MANY_REQUESTS
+         ):
+             res.close()
+             return status, metadata
+         if "text/event-stream" in res.headers['content-type']:
+             return sse_stream(res), metadata
+         res.read()
+         if status is httpx.codes.TOO_MANY_REQUESTS:
+             return QuotaInfos(**res.json()), metadata # pragma: no cover
+         if status is httpx.codes.OK:
+             return ScheduleResponse(**res.json()), metadata
+         assert_never(status)
+
+     def allow(
+         self,
+         allow_token: str,
+         pid: int,
+     ):
+         res = self.client.post('/allow', params={
+             'allowToken': allow_token,
+             'pid': pid,
+         })
+         return httpx.codes(res.status_code)
+
+     def release(
+         self,
+         allow_token: str,
+         fail: bool = False,
+     ) -> httpx.codes:
+         res = self.client.post('/release', params={
+             'allowToken': allow_token,
+             'fail': fail,
+         })
+         return httpx.codes(res.status_code)
+
+     def get_queue_size(self) -> float:
+         res = self.client.get('/queue-size')
+         assert res.status_code == 200, res.status_code
+         size = res.json()
+         return size
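The device API answers `/schedule` either with a plain JSON body or, when the request is queued, with a `text/event-stream` whose chunks `sse_parse` turns into `QueueEvent` objects. A sketch of the parsing step alone, assuming `httpx` and `pydantic` are installed and the JSON values are illustrative:

```
from spaces.zero.api import sse_parse

chunk = (
    "event: succeeded\n"
    'data: {"idle": false, "nvidiaIndex": 0, "nvidiaUUID": "GPU-0", "allowToken": "tok"}\n'
)
event = sse_parse(chunk)
assert event.event == 'succeeded'
assert event.data is not None and event.data.allowToken == 'tok'

# ping/failed events carry no data payload
assert sse_parse("event: ping\n").event == 'ping'
```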
spaces/zero/client.py ADDED
@@ -0,0 +1,274 @@
+ """
+ """
+ from __future__ import annotations
+
+ import os
+ import time
+ import warnings
+ from datetime import timedelta
+ from typing import Any
+
+ import gradio as gr
+ import httpx
+ from packaging import version
+ from typing_extensions import assert_never
+
+ from .. import utils
+ from ..config import Config
+ from .api import APIClient
+ from .api import AuthLevel
+ from .api import QuotaInfos
+ from .api import ScheduleResponse
+ from .gradio import info
+ from .gradio import error
+ from .gradio import get_event
+ from .gradio import supports_auth
+
+
+ TOKEN_HEADER = 'X-IP-Token'
+ DEFAULT_SCHEDULE_DURATION = 60
+
+ UNUSED_MESSAGE = "GPU device not used"
+ NO_GPU_MESSAGE_REGULAR = "No GPU was available"
+ NO_GPU_MESSAGE_INQUEUE = "No GPU was available after 60s"
+ EXAMPLES_RETRY_MESSAGE = "Try re-running outside of examples if it happened after clicking one"
+
+ SIGNUP_ON_HF_TXT = "Create a free account"
+ SIGNUP_ON_HF_URL = "https://huggingface.co/join"
+ SUBSCRIBE_TO_PRO_TXT = "Subscribe to Pro"
+ SUBSCRIBE_TO_PRO_URL = "https://huggingface.co/settings/billing/subscription"
+
+
+ def api_client():
+     assert Config.zero_device_api_url is not None
+     httpx_client = httpx.Client(base_url=Config.zero_device_api_url, timeout=60, verify=False)
+     return APIClient(httpx_client)
+
+
+ def startup_report():
+     retries, max_retries = 0, 2
+     client = api_client()
+     while (status := client.startup_report()) is httpx.codes.NOT_FOUND: # pragma: no cover
+         time.sleep(1)
+         if (retries := retries + 1) > max_retries:
+             raise RuntimeError("Error while initializing ZeroGPU: NotFound")
+     if status is not httpx.codes.OK: # pragma: no cover
+         raise RuntimeError("Error while initializing ZeroGPU: Unknown")
+
+
+ def html_string(html_contents: str, text_contents: str): # pragma: no cover
+     class HTMLString(str):
+         def __str__(self):
+             return text_contents
+     return HTMLString(html_contents)
+
+
+ def _toast_action(
+     auth: AuthLevel | None,
+     supports_html: bool,
+     pro_message: str,
+     unlogged_desc: str,
+     logged_desc: str,
+     ending: str,
+ ) -> tuple[str, str]: # pragma: no cover
+     if not supports_auth() or auth == 'pro':
+         return pro_message, pro_message
+     html = ""
+     link = SIGNUP_ON_HF_URL if auth is None else SUBSCRIBE_TO_PRO_URL
+     text = SIGNUP_ON_HF_TXT if auth is None else SUBSCRIBE_TO_PRO_TXT
+     desc = unlogged_desc if auth is None else logged_desc
+     desc += f" {ending}."
+     style = ";".join([
+         "white-space: nowrap",
+         "text-underline-offset: 2px",
+         "color: var(--body-text-color)",
+     ])
+     if supports_html:
+         html += f'<a style="{style}" href="{link}">'
+     html += text
+     if supports_html:
+         html += '</a>'
+     html += f" {desc}"
+     markdown = f'[{text}]({link}) {desc}'
+     return html, markdown
+
+
+ def schedule(
+     task_id: int,
+     request: gr.Request | None = None,
+     duration: timedelta | None = None,
+     _first_attempt: bool = True,
+ ) -> ScheduleResponse:
+
+     if not (gradio_version := version.parse(gr.__version__)).major >= 4: # pragma: no cover
+         raise RuntimeError("ZeroGPU is only compatible with Gradio 4+")
+
+     GRADIO_HTML_TOASTS = gradio_version >= version.Version('4.39')
+     GRADIO_HANDSHAKE = gradio_version >= version.Version('5.16.1')
+
+     token, payload = _get_token_and_payload(request)
+     if token is not None and (token_error := payload.get('error')):
+         message = f"Falling back to IP-based quotas ({token_error})"
+         info("ZeroGPU client warning", message, level='warning')
+
+     res, meta = api_client().schedule(
+         cgroup_path=utils.self_cgroup_device_path(),
+         task_id=task_id,
+         token=token,
+         token_version=2 if GRADIO_HANDSHAKE else 1,
+         duration_seconds=duration.seconds if duration is not None else None,
+     )
+
+     auth = meta.auth
+
+     if isinstance(res, ScheduleResponse):
+         return res
+
+     if isinstance(res, QuotaInfos): # pragma: no cover
+         requested = duration.seconds if duration is not None else DEFAULT_SCHEDULE_DURATION
+         if res.wait < timedelta(0):
+             message = (
+                 f"The requested GPU duration ({requested}s) "
+                 f"is larger than the maximum allowed"
+             )
+             raise error("ZeroGPU illegal duration", message)
+         elif token is None:
+             message = (
+                 f"Space app has reached its GPU limit. "
+                 f"{EXAMPLES_RETRY_MESSAGE}"
+             )
+             raise error("ZeroGPU quota exceeded", message)
+         else:
+             if payload.get('user') is None and res.wait == 0:
+                 message = "You have exceeded your runs limit."
+             else:
+                 gpu = "Pro GPU" if auth == 'pro' else ("free GPU" if auth == 'regular' else "GPU")
+                 message = (
+                     f"You have exceeded your {gpu} quota "
+                     f"({requested}s requested vs. {res.left}s left). "
+                     f"Try again in {res.wait}"
+                 )
+             raise error("ZeroGPU quota exceeded", message)
+
+     if not isinstance(res, httpx.codes): # pragma: no cover
+         if meta.queuing_reason in ('node', None):
+             info("ZeroGPU queue", "Waiting for a GPU to become available")
+         elif meta.queuing_reason == 'concurrency':
+             info("ZeroGPU queue", "Waiting for a GPU slot on this Space")
+         else:
+             assert_never(meta.queuing_reason)
+         # TODO: Sign-up message if not authenticated (after some time ?)
+         connection_event = get_event()
+         if connection_event is None and request is not None:
+             warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
+         while True:
+             try:
+                 event = next(res)
+             except StopIteration:
+                 raise RuntimeError("Unexpected end of stream")
+             except httpx.RemoteProtocolError:
+                 if not _first_attempt:
+                     raise RuntimeError("Error while re-trying after queue disconnect")
+                 return schedule(task_id, request, duration, _first_attempt=False)
+             if event.event == 'ping':
+                 if connection_event is not None and not connection_event.alive:
+                     res.close()
+                     raise RuntimeError("Connection closed by visitor while queueing")
+                 continue
+             if event.event == 'failed':
+                 if token is None:
+                     message = f"{NO_GPU_MESSAGE_INQUEUE}. {EXAMPLES_RETRY_MESSAGE}"
+                     raise error("ZeroGPU quota exceeded", message)
+                 details_html, details_markdown = _toast_action(
+                     auth=auth,
+                     supports_html=GRADIO_HTML_TOASTS,
+                     pro_message="Retry later",
+                     unlogged_desc="to get a higher",
+                     logged_desc="to get the highest",
+                     ending="priority in ZeroGPU queues",
+                 )
+                 message_html = f"{NO_GPU_MESSAGE_INQUEUE}. {details_html}"
+                 message_text = f"{NO_GPU_MESSAGE_INQUEUE} {details_markdown}"
+                 message = html_string(message_html, message_text)
+                 raise error("ZeroGPU queue timeout", message, html=True)
+             if event.event == 'succeeded':
+                 assert event.data is not None
+                 if connection_event is not None and not connection_event.alive:
+                     release(event.data.allowToken)
+                     raise RuntimeError("Connection closed by visitor on queue success")
+                 info("ZeroGPU queue", "Successfully acquired a GPU", level='success')
+                 return event.data
+
+     if res is httpx.codes.SERVICE_UNAVAILABLE:
+         raise error("ZeroGPU client error", NO_GPU_MESSAGE_REGULAR)
+
+     if res is httpx.codes.UNAUTHORIZED: # pragma: no cover
+         raise error("ZeroGPU client error", "Expired ZeroGPU proxy token")
+
+     # TODO: Find a way to log 'detail' response field
+     raise RuntimeError(f"ZeroGPU API /schedule error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
+
+
+ def allow(allow_token: str) -> None:
+     pid = os.getpid()
+     assert pid != 1, "Allowing PID 1 on ZeroGPU will end up killing your Space"
+     assert api_client().allow(allow_token=allow_token, pid=pid) is httpx.codes.OK
+
+
+ def release(
+     allow_token: str, *,
+     fail: bool = False,
+     allow_404: bool = False,
+ ) -> None:
+
+     res = api_client().release(
+         allow_token=allow_token,
+         fail=fail,
+     )
+
+     if res is httpx.codes.NO_CONTENT: # pragma: no cover
+         try:
+             info("ZeroGPU client warning", UNUSED_MESSAGE, level='warning')
+         except AttributeError:
+             pass
+         warnings.warn(UNUSED_MESSAGE, RuntimeWarning)
+         return None
+
+     if res is httpx.codes.NOT_FOUND:
+         if not allow_404:
+             warnings.warn("ZeroGPU API /release warning: 404 Not Found")
+         return None
+
+     if httpx.codes.is_success(res):
+         return None
+
+     # TODO: Find a way to log 'detail' response field
+     # TODO: Only raise in dev environment. Simply warn in production ?
+     raise RuntimeError(f"ZeroGPU API /release error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
+
+
+ def _get_token(request: gr.Request | None) -> str | None:
+
+     if request is None:
+         return None
+
+     headers = getattr(request, 'headers', None)
+     if headers is None or not hasattr(headers, '__dict__'):
+         raise error("ZeroGPU client error", "Internal Gradio error")
+
+     # Compatibility trick
+     if not hasattr(headers, 'get'):
+         headers = headers.__dict__ # pragma: no cover
+
+     return headers.get(TOKEN_HEADER.lower())
+
+
+ def _get_token_and_payload(request: gr.Request | None) -> tuple[str | None, dict[str, Any]]:
+     if (token := _get_token(request)) is None:
+         return None, {}
+     try:
+         payload = utils.jwt_payload(token)
+     except Exception: # pragma: no cover
+         warnings.warn("Error while decoding X-IP-Token JWT")
+         return token, {}
+     return token, payload
spaces/zero/decorator.py ADDED
@@ -0,0 +1,113 @@
+ """
+ """
+ from __future__ import annotations
+
+ import inspect
+ import sys
+ import warnings
+ from datetime import timedelta
+ from functools import partial
+ from typing import Callable
+ from typing import TypeVar
+ from typing import overload
+ from typing_extensions import ParamSpec
+ from typing_extensions import Unpack
+
+ from ..config import Config
+ from .types import DynamicDuration
+ from .types import EmptyKwargs
+
+
+ P = ParamSpec('P')
+ R = TypeVar('R')
+
+
+ decorated_cache: dict[Callable, Callable] = {}
+
+
+ @overload
+ def GPU(
+     task: None = None, *,
+     duration: DynamicDuration[P] = None,
+ ) -> Callable[[Callable[P, R]], Callable[P, R]]:
+     ...
+ @overload
+ def GPU(
+     task: Callable[P, R], *,
+     duration: DynamicDuration[P] = None,
+ ) -> Callable[P, R]:
+     ...
+ def GPU(
+     task: Callable[P, R] | None = None, *,
+     duration: DynamicDuration[P] = None,
+     **kwargs: Unpack[EmptyKwargs],
+ ) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]:
+     """
+     ZeroGPU decorator
+
+     Basic usage:
+     ```
+     @spaces.GPU
+     def fn(...):
+         # CUDA is available here
+         pass
+     ```
+
+     With custom duration:
+     ```
+     @spaces.GPU(duration=45) # Expressed in seconds
+     def fn(...):
+         # CUDA is available here
+         pass
+     ```
+
+     Args:
+         task (`Callable | None`): Python function that requires CUDA
+         duration (`int | datetime.timedelta`): Estimated duration in seconds or `datetime.timedelta`
+
+     Returns:
+         `Callable`: GPU-ready function
+     """
+     if "enable_queue" in kwargs:
+         warnings.warn("`enable_queue` parameter is now ignored and always set to `True`")
+     if task is None:
+         return partial(_GPU, duration=duration)
+     return _GPU(task, duration)
+
+
+ def _GPU(
+     task: Callable[P, R],
+     duration: DynamicDuration[P],
+ ) -> Callable[P, R]:
+
+     if not Config.zero_gpu:
+         return task
+
+     from . import client
+     from .wrappers import regular_function_wrapper
+     from .wrappers import generator_function_wrapper
+
+     if sys.version_info.minor < 9: # pragma: no cover
+         raise RuntimeError("Actually using @spaces.GPU on a ZeroGPU Space requires Python 3.9+")
+
+     if task in decorated_cache:
+         # TODO: Assert same duration ?
+         return decorated_cache[task] # type: ignore
+
+     if inspect.iscoroutinefunction(task):
+         raise NotImplementedError
+
+     if inspect.isgeneratorfunction(task):
+         decorated = generator_function_wrapper(task, duration)
+     else:
+         decorated = regular_function_wrapper(task, duration)
+
+     setattr(decorated, 'zerogpu', None)
+
+     client.startup_report()
+     decorated_cache.update({
+         task: decorated,
+         decorated: decorated,
+     })
+
+     return decorated # type: ignore
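Besides a fixed number of seconds or a `timedelta`, the `DynamicDuration[P]` annotation suggests that `duration` can also be a callable receiving the same arguments as the task and returning the estimate at call time. A hedged sketch of that usage (the heuristic itself is illustrative):

```
import spaces

def get_duration(prompt: str, steps: int = 20):
    # illustrative heuristic: budget roughly one second per step
    return 10 + steps

@spaces.GPU(duration=get_duration)
def generate(prompt: str, steps: int = 20):
    ...
```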
spaces/zero/gradio.py ADDED
@@ -0,0 +1,190 @@
+ """
+ """
+ from __future__ import annotations
+
+ import inspect
+ from functools import wraps
+ from packaging import version
+ from typing import Any
+ from typing import Callable
+ from typing import Literal
+ from typing import NamedTuple
+ from typing import TYPE_CHECKING
+ import warnings
+
+ import gradio as gr
+ from gradio.context import Context
+ from gradio.context import LocalContext
+ from gradio.helpers import Progress
+ from gradio.helpers import TrackedIterable
+ from gradio.queueing import Queue
+ from typing_extensions import ParamSpec
+ from typing_extensions import TypeAlias
+
+ from ..utils import SimpleQueue
+ from .types import GeneratorResQueueResult
+ from .types import GradioQueueEvent
+ from .types import RegularResQueueResult
+
+
+ QUEUE_RPC_METHODS = [
+     "set_progress",
+     "log_message",
+ ]
+
+
+ try:
+     Success = gr.Success # pyright: ignore[reportAttributeAccessIssue] (Gradio<5.10)
+ except AttributeError: # pragma: no cover
+     Success = gr.Info
+
+ Level: TypeAlias = "Literal['success', 'info', 'warning']"
+
+ def modal(level: Level):
+     if level == 'info':
+         return gr.Info
+     if level == 'success':
+         return Success
+     if level == 'warning':
+         return gr.Warning
+
+
+ class GradioPartialContext(NamedTuple):
+     event_id: str | None
+     in_event_listener: bool
+     progress: Progress | None
+
+     @staticmethod
+     def get():
+         TrackedIterable.__reduce__ = tracked_iterable__reduce__
+         return GradioPartialContext(
+             event_id=LocalContext.event_id.get(),
+             in_event_listener=LocalContext.in_event_listener.get(),
+             progress=LocalContext.progress.get(),
+         )
+
+     @staticmethod
+     def apply(context: 'GradioPartialContext'):
+         LocalContext.event_id.set(context.event_id)
+         LocalContext.in_event_listener.set(context.in_event_listener)
+         LocalContext.progress.set(context.progress)
+
+
+ def get_queue_instance():
+     blocks = LocalContext.blocks.get()
+     if blocks is None: # pragma: no cover
+         return None
+     return blocks._queue
+
+
+ def get_event():
+     queue = get_queue_instance()
+     event_id = LocalContext.event_id.get()
+     if queue is None:
+         return None
+     if event_id is None: # pragma: no cover
+         return None
+     for job in queue.active_jobs:
+         if job is None: # pragma: no cover
+             continue
+         for event in job:
+             if event._id == event_id:
+                 return event
+
+
+ def get_server_port() -> int | None:
+     from_request_context = True
+     if (blocks := LocalContext.blocks.get()) is None: # Request
+         from_request_context = False
+         if (blocks := Context.root_block) is None: # Caching
+             return None
+     if (server := getattr(blocks, 'server', None)) is None: # pragma: no cover (Gradio 4)
+         if from_request_context:
+             warnings.warn("Gradio: No blocks.server inside a request") # pragma: no cover
+         return -1
+     if TYPE_CHECKING:
+         assert (server := blocks.server)
+     return server.config.port
+
+
+ def try_process_queue_event(method_name: str, *args, **kwargs):
+     queue = get_queue_instance()
+     if queue is None: # pragma: no cover
+         warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
+         return
+     method = getattr(queue, method_name, None)
+     assert callable(method)
+     method(*args, **kwargs)
+
+
+ def patch_gradio_queue(
+     res_queue: SimpleQueue[RegularResQueueResult | None] | SimpleQueue[GeneratorResQueueResult | None],
+ ):
+
+     def rpc_method(method_name: str):
+         def method(*args, **kwargs):
+             if args and isinstance(args[0], Queue):
+                 args = args[1:] # drop `self`
+             res_queue.put(GradioQueueEvent(method_name, args, kwargs))
+         return method
+
+     for method_name in QUEUE_RPC_METHODS:
+         if (method := getattr(Queue, method_name, None)) is None: # pragma: no cover
+             warnings.warn(f"ZeroGPU: Gradio Queue has no {method_name} attribute")
+             continue
+         if not callable(method): # pragma: no cover
+             warnings.warn(f"ZeroGPU: Gradio Queue {method_name} is not callable")
+             continue
+         setattr(Queue, method_name, rpc_method(method_name))
+
+     TrackedIterable.__reduce__ = tracked_iterable__reduce__
+
+
+ def tracked_iterable__reduce__(self):
+     res: tuple = super(TrackedIterable, self).__reduce__() # type: ignore
+     cls, base, state, *_ = res
+     return cls, base, {**state, **{
+         'iterable': None,
+         '_tqdm': None,
+     }}
+
+
+ def supports_auth():
+     return version.parse(gr.__version__) >= version.Version('4.27.0')
+
+
+ Param = ParamSpec('Param')
+
+ def one_launch(task: Callable[Param, None], *task_args: Param.args, **task_kwargs: Param.kwargs):
+     _launch = gr.Blocks.launch
+     @wraps(gr.Blocks.launch)
+     def launch(*args, **kwargs):
+         task(*task_args, **task_kwargs)
+         gr.Blocks.launch = _launch
+         return gr.Blocks.launch(*args, **kwargs)
+     gr.Blocks.launch = launch
+
+
+ class HTMLError(gr.Error):
+     def __str__(self): # pragma: no cover
+         return self.message
+
+
+ def error(title: str, message: str, html: bool = False):
+     params = inspect.signature(gr.Error).parameters
+     kwargs: dict[str, Any] = {}
+     if 'title' in params:
+         kwargs = {**kwargs, 'title': title}
+     if 'print_exception' in params:
+         kwargs = {**kwargs, 'print_exception': False}
+     error_cls = HTMLError if html else gr.Error
+     return error_cls(message, **kwargs)
+
+
+ def info(title: str, message: str, level: Level = 'info'):
+     params = inspect.signature(gr.Info).parameters
+     kwargs: dict[str, Any] = {}
+     if 'title' in params:
+         kwargs = {**kwargs, 'title': title}
+     info_cls = modal(level)
+     return info_cls(message, **kwargs)
spaces/zero/torch/__init__.py ADDED
@@ -0,0 +1,42 @@
+ """
+ """
+
+ from ...config import Config
+
+
+ try:
+
+     import torch
+
+ except ImportError:
+
+     _patch = lambda *args, **kwargs: None
+     _unpatch = lambda *args, **kwargs: None
+     _pack = lambda *args, **kwargs: None
+     _init = lambda *args, **kwargs: None
+     _size = lambda *args, **kwargs: 0
+     _move = lambda *args, **kwargs: None
+     _is_in_bad_fork = lambda *args, **kwargs: False
+
+ else:
+
+     if Config.zero_gpu_v2:
+         from . import patching as _patching
+     else: # pragma: no cover
+         from . import patching_legacy as _patching
+
+     _patch = _patching.patch
+     _unpatch = _patching.unpatch
+     _pack = _patching.pack
+     _init = _patching.init
+     _size = _patching.size
+     _move = _patching.move
+     _is_in_bad_fork = _patching.is_in_bad_fork
+
+ patch = _patch
+ unpatch = _unpatch
+ pack = _pack
+ init = _init
+ size = _size
+ move = _move
+ is_in_bad_fork = _is_in_bad_fork
spaces/zero/torch/bitsandbytes.py ADDED
@@ -0,0 +1,170 @@
+ """
+ """
+ # pyright: reportPrivateImportUsage=false
+
+ from __future__ import annotations
+
+ import importlib
+ from contextlib import contextmanager
+ from contextlib import nullcontext
+ from importlib import metadata
+ from types import ModuleType
+ from typing import TYPE_CHECKING
+ from typing import Tuple
+
+ import torch
+ from packaging import version
+
+ if TYPE_CHECKING:
+     import torch as Torch
+
+
+ @contextmanager
+ def cuda_unavailable(torch: ModuleType): # pragma: no cover
+     _is_available = torch.cuda.is_available
+     torch.cuda.is_available = lambda: False
+     yield
+     torch.cuda.is_available = _is_available
+
+
+ def maybe_import_bitsandbytes():
+     try:
+         import torch
+     except ImportError: # pragma: no cover
+         return None
+     try:
+         bnb_version = version.parse(metadata.version('bitsandbytes'))
+     except ImportError: # pragma: no cover
+         return None
+     if bnb_version < version.parse('0.40.0'): # pragma: no cover
+         raise RuntimeError(f"ZeroGPU requires bitsandbytes >= 0.40.0 (installed: {bnb_version})")
+     if bnb_version < version.parse('0.43.1'): # pragma: no cover
+         context = lambda: cuda_unavailable(torch)
+     else:
+         context = lambda: nullcontext()
+     with (ctx := context()):
+         try:
+             import bitsandbytes
+         except ImportError:
+             return None
+     if not isinstance(ctx, nullcontext): # pragma: no cover
+         print("↑ Those bitsandbytes warnings are expected on ZeroGPU ↑")
+     return context
+
+
+ if (import_context := maybe_import_bitsandbytes()):
+
+     from torch.utils.weak import WeakTensorKeyDictionary
+
+     with (import_ctx := import_context()):
+         CUDASetup = None
+         if not isinstance(import_ctx, nullcontext): # pragma: no cover
+             from bitsandbytes.cuda_setup.main import CUDASetup # pyright: ignore [reportMissingImports]
+         from bitsandbytes import cextension
+         from bitsandbytes import functional
+         from bitsandbytes.nn import Int8Params
+         from bitsandbytes.nn import Params4bit
+
+     _param_to_8bit = Int8Params.to # type: ignore
+     _param_cuda_8bit = Int8Params.cuda
+     _param_to_4bit = Params4bit.to # type: ignore
+     _param_cuda_4bit = Params4bit.cuda
+
+     TensorToArgs = Tuple[torch.device, torch.dtype, bool, torch.memory_format]
+
+     to_ops_8bit: dict[Int8Params, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
+     to_ops_4bit: dict[Params4bit, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
+
+     def _to_op_register_8bit(self: Int8Params, *args, **kwargs):
+         parsed = torch._C._nn._parse_to(*args, **kwargs)
+         device, *_ = parsed
+         if not isinstance(device, torch.device): # pragma: no cover
+             return _param_to_8bit(self, *args, **kwargs)
+         if device.type != 'cuda':
+             return _param_to_8bit(self, *args, **kwargs)
+         to_ops_8bit[self] = parsed
+         return self
+
+     def _to_op_register_4bit(self: Params4bit, *args, **kwargs):
+         parsed = torch._C._nn._parse_to(*args, **kwargs)
+         device, *_ = parsed
+         if not isinstance(device, torch.device): # pragma: no cover
+             return _param_to_4bit(self, *args, **kwargs)
+         if device.type != 'cuda':
+             return _param_to_4bit(self, *args, **kwargs)
+         to_ops_4bit[self] = parsed
+         return self
+
+     def _cuda_op_arg_check(device: Torch.device | int | str | None) -> bool:
+         if device is None: # pragma: no cover
+             return True
+         if isinstance(device, int):
+             return True
+         if isinstance(device, str): # pragma: no cover
+             device = torch.device(device)
+         return device.type == 'cuda' # pragma: no cover
+
+     def _cuda_op_register_8bit(self: Int8Params, device: Torch.device | int | str | None = None, **kwargs):
+         if not _cuda_op_arg_check(device): # pragma: no cover
+             # Let PyTorch handle the fail
+             return _param_cuda_8bit(self, device, **kwargs)
+         to_ops_8bit[self] = None
+         return self
+
+     def _cuda_op_register_4bit(self: Params4bit, device: Torch.device | int | str | None = None, **kwargs):
+         if not _cuda_op_arg_check(device): # pragma: no cover
+             # Let PyTorch handle the fail
+             return _param_cuda_4bit(self, device, **kwargs)
+         to_ops_4bit[self] = None
+         return self
+
+     def _patch():
+         Int8Params.to = _to_op_register_8bit # type: ignore
+         Int8Params.cuda = _cuda_op_register_8bit # type: ignore
+         Params4bit.to = _to_op_register_4bit # type: ignore
+         Params4bit.cuda = _cuda_op_register_4bit # type: ignore
+
+     def _unpatch():
+         Int8Params.to = _param_to_8bit # type: ignore
+         Int8Params.cuda = _param_cuda_8bit
+         Params4bit.to = _param_to_4bit # type: ignore
+         Params4bit.cuda = _param_cuda_4bit
+
+     def _move():
+         if CUDASetup is not None: # pragma: no cover
+             CUDASetup._instance = None
+             importlib.reload(cextension)
+             functional.lib = cextension.lib
+         for op in to_ops_8bit.items():
+             tensor, parsed_args = op
+             if parsed_args:
+                 _, dtype, _, memory_format = parsed_args
+             else:
+                 dtype, memory_format = None, None
+             tensor.data = _param_to_8bit(tensor,
+                 device='cuda',
+                 dtype=dtype,
+                 memory_format=memory_format,
+             ) # type: ignore
+         for op in to_ops_4bit.items():
+             tensor, parsed_args = op
+             if parsed_args:
+                 _, dtype, _, memory_format = parsed_args
+             else:
+                 dtype, memory_format = None, None
+             tensor.data = _param_to_4bit(tensor,
+                 device='cuda',
+                 dtype=dtype,
+                 memory_format=memory_format,
+             ) # type: ignore
+
+ else:
+
+     _patch = lambda: None
+     _unpatch = lambda: None
+     _move = lambda: None
+
+
+ patch = _patch
+ unpatch = _unpatch
+ move = _move
spaces/zero/torch/cudart.py ADDED
@@ -0,0 +1,8 @@
+ """
+ """
+
+ from .static import CUDA_MEM_GET_INFO
+
+
+ def cudaMemGetInfo(device: int, /):
+     return CUDA_MEM_GET_INFO
spaces/zero/torch/packing.py ADDED
@@ -0,0 +1,209 @@
+ """
+ """
+ from __future__ import annotations
+
+ import time
+
+ import ctypes
+ import os
+ from concurrent.futures import as_completed
+ from concurrent.futures import ThreadPoolExecutor
+ from contextvars import copy_context
+ from dataclasses import dataclass
+ from queue import Queue
+ from typing import Callable
+
+ from ...utils import debug
+
+ import torch
+ from typing_extensions import TypeAlias
+
+
+ PAGE_SIZE = 4096
+ TOTAL_MEMORY = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
+ VM_MAX_SIZE = min(2**38, TOTAL_MEMORY // 2)
+
+ BUFFER_SIZE = 64 * 2**20
+ BUFFER_COUNT = 2
+
+
+ TensorWithSizes: TypeAlias = 'tuple[torch.Tensor, int, int]'
+
+ @dataclass
+ class ZeroGPUTensorPack:
+     base_dir: str
+     batches: list[list[TensorWithSizes]]
+     big_tensors: list[TensorWithSizes]
+     fakes: dict[torch.Tensor, list[torch.Tensor]]
+     total_size: int
+     def path(self):
+         return f'{self.base_dir}/{id(self)}'
+     def __del__(self):
+         try:
+             os.remove(self.path())
+         except FileNotFoundError: # pragma: no cover
+             pass
+
+
+ def write(fd: int, tensor: torch.Tensor):
+     clone = torch.empty_like(tensor)
+     size = clone.untyped_storage().size() # pyright: ignore [reportAttributeAccessIssue]
+     buffer = torch.UntypedStorage(VM_MAX_SIZE)
+     buffer_ptr = buffer.data_ptr()
+     offset = -buffer_ptr % PAGE_SIZE
+     padding = -size % PAGE_SIZE
+     clone.set_(buffer[offset:offset+size], 0, clone.shape, clone.stride()) # pyright: ignore [reportArgumentType]
+     clone.copy_(tensor)
+     mv = memoryview((ctypes.c_char * (size+padding)).from_address(buffer_ptr+offset))
+     written_bytes = 0
+     while written_bytes < size:
+         written_bytes += os.write(fd, mv[written_bytes:])
+
+
+ def pack_tensors(
+     tensors: set[torch.Tensor],
+     fakes: dict[torch.Tensor, list[torch.Tensor]],
+     offload_dir: str,
+     callback: Callable[[int]] | None = None,
+ ):
+
+     callback = (lambda bytes: None) if callback is None else callback
+
+     batches: list[list[TensorWithSizes]] = []
+     big_tensors: list[TensorWithSizes] = []
+
+     tensors_with_sizes: list[tuple[torch.Tensor, int, int]] = []
+     for tensor in tensors:
+         size = tensor.numel() * tensor.element_size()
+         aligned_size = size + (-size % PAGE_SIZE)
+         tensors_with_sizes += [(tensor, size, aligned_size)]
+
+     current_batch, current_size = [], 0
+     for (tensor, size, aligned_size) in sorted(tensors_with_sizes, key=lambda item: item[2]):
+         if aligned_size > BUFFER_SIZE:
+             big_tensors += [(tensor, size, aligned_size)]
+             continue
+         current_size += aligned_size
+         if current_size > BUFFER_SIZE:
+             batches += [current_batch]
+             current_batch, current_size = [(tensor, size, aligned_size)], aligned_size
+         else:
+             current_batch += [(tensor, size, aligned_size)]
+
+     if current_batch:
+         batches += [current_batch]
+
+     get_meta = {tensor: torch.empty_like(tensor) for tensor in tensors}
+     batches_meta = [[(get_meta[tensor], size, asize) for tensor, size, asize in batch] for batch in batches]
+     big_tensors_meta = [(get_meta[tensor], size, asize) for tensor, size, asize in big_tensors]
+     fakes_meta = {get_meta[tensor]: fake_list for tensor, fake_list in fakes.items()}
+
+     pack = ZeroGPUTensorPack(
+         base_dir=offload_dir,
+         batches=batches_meta,
+         big_tensors=big_tensors_meta,
+         fakes=fakes_meta,
+         total_size=sum([size for _, size, _ in tensors_with_sizes]),
+     )
+
+     fd = os.open(pack.path(), os.O_CREAT | os.O_WRONLY | os.O_DIRECT)
+     try:
+         total_asize = sum([aligned_size for batch in batches for *_, aligned_size in batch])
+         total_asize += sum([aligned_size for *_, aligned_size in big_tensors])
+         if total_asize > 0:
+             os.posix_fallocate(fd, 0, total_asize)
+         for batch in batches:
+             for tensor, size, _ in batch:
+                 write(fd, tensor)
+                 callback(size)
+         for tensor, size, _ in big_tensors:
+             write(fd, tensor)
+             callback(size)
+         return pack
+     finally:
+         os.close(fd)
+
+
+ def pack_to_cuda(pack: ZeroGPUTensorPack, callback: Callable[[int]] | None = None):
+
+     callback = (lambda bytes: None) if callback is None else callback
+
+     free_buffers: Queue[torch.Tensor] = Queue()
+     read_buffers: Queue[torch.Tensor] = Queue()
+
+     for _ in range(BUFFER_COUNT):
+         free_buffers.put(torch.ByteTensor(BUFFER_SIZE).pin_memory())
+
+     def read(fd: int, buffer: torch.Tensor, size: int):
+         mv = memoryview((ctypes.c_char * size).from_address(buffer.data_ptr()))
+         read_bytes = 0
+         while read_bytes < size:
+             read_bytes += os.readv(fd, [mv[read_bytes:]])
+
+     def disk_to_pin(fd: int):
+         for batch in pack.batches:
+             buffer = free_buffers.get()
+             batch_size = sum([aligned_size for *_, aligned_size in batch])
+             read(fd, buffer, batch_size)
+             read_buffers.put(buffer)
+         for *_, aligned_size in pack.big_tensors:
+             read_bytes = 0
+             while read_bytes < aligned_size:
+                 buffer = free_buffers.get()
+                 read_size = min(BUFFER_SIZE, aligned_size - read_bytes)
+                 read(fd, buffer, read_size)
+                 read_buffers.put(buffer)
+                 read_bytes += read_size
+
+     def pin_to_cuda():
+         total_duration_in_callback = 0
+         for batch in pack.batches:
+             buffer = read_buffers.get()
+             offset = 0
+             cuda_storages = []
+             for tensor, size, aligned_size in batch:
+                 cuda_storages += [buffer[offset:offset+size].cuda(non_blocking=True)]
+                 offset += aligned_size
+             torch.cuda.synchronize()
+             free_buffers.put(buffer)
+             batch_total_size = 0
+             for (tensor, size, _), cuda_storage in zip(batch, cuda_storages):
+                 cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
+                 cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
+                 for fake in pack.fakes[tensor]:
+                     fake.data = cuda_tensor
+                 batch_total_size += size
+             t0 = time.perf_counter()
+             callback(batch_total_size)
+             total_duration_in_callback += time.perf_counter() - t0
+         for tensor, size, _ in pack.big_tensors:
+             cuda_storage = torch.empty(size, dtype=torch.uint8, device='cuda')
+             offset = 0
+             while offset < size:
+                 buffer = read_buffers.get()
+                 read_size = min(BUFFER_SIZE, size - offset)
+                 cuda_storage[offset:offset+read_size] = buffer[:read_size]
+                 offset += read_size
+                 torch.cuda.synchronize() # Probably not needed
+                 free_buffers.put(buffer)
+                 t0 = time.perf_counter()
+                 callback(read_size)
+                 total_duration_in_callback += time.perf_counter() - t0
+             cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
+             cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
+             for fake in pack.fakes[tensor]:
+                 fake.data = cuda_tensor
+
+         debug(f"{total_duration_in_callback=}")
+
+     with ThreadPoolExecutor(2) as e:
+         fd = os.open(pack.path(), os.O_RDONLY | os.O_DIRECT)
+         try:
+             futures = [
+                 e.submit(copy_context().run, disk_to_pin, fd),
+                 e.submit(copy_context().run, pin_to_cuda),
+             ]
+             for future in as_completed(futures):
+                 future.result()
+         finally:
+             os.close(fd)
spaces/zero/torch/patching.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ # pyright: reportPrivateImportUsage=false
4
+
5
+ from __future__ import annotations
6
+
7
+ import gc
8
+ import multiprocessing
9
+ import os
10
+ from collections import defaultdict
11
+ from concurrent.futures import ProcessPoolExecutor
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ from contextlib import nullcontext
14
+ from contextvars import copy_context
15
+ from typing import Any
16
+ from typing import Callable
17
+
18
+ import torch
19
+ from torch.overrides import TorchFunctionMode
20
+ from torch.overrides import resolve_name
21
+ from torch.utils._python_dispatch import TorchDispatchMode
22
+ from torch.utils._pytree import tree_map_only
23
+ from torch.utils.weak import WeakTensorKeyDictionary
24
+
25
+ from ...config import Config
26
+ from ...utils import malloc_trim
27
+ from ..tqdm import tqdm
28
+ from . import cudart
29
+ from .packing import ZeroGPUTensorPack
30
+ from .packing import pack_tensors
31
+ from .packing import pack_to_cuda
32
+ from .static import *
33
+ from .types import AliasId
34
+
35
+
36
+ OPS_INPUTS_CHECK_NO_RETURN = (
37
+ torch.Tensor.equal,
38
+ )
39
+
40
+ OPS_INPUT_CHECK_SELF_RETURN = (
41
+ torch.Tensor.set_, # probably never dispatched
42
+ torch.ops.aten.set_.source_Tensor, # pyright: ignore [reportAttributeAccessIssue]
43
+ )
44
+
45
+ OFFLOADED_ERROR_MESSAGE = "Cannot apply function {} on disk-offloaded Tensor {}"
46
+
47
+ _tensor_make_subclass = torch.Tensor._make_subclass
48
+ _asarray = torch.asarray
49
+ _device = torch.device
50
+ _cuda_init = torch._C._cuda_init
51
+ _cuda_exchange_device = torch.cuda._exchange_device
52
+ _cuda_available = torch.cuda.is_available
53
+ _cuda_device_count = torch.cuda.device_count
54
+ _cuda_current_device = torch.cuda.current_device
55
+ _cuda_get_device_capability = torch.cuda.get_device_capability
56
+ _cuda_get_device_properties = torch.cuda.get_device_properties
57
+ _cuda_get_device_name = torch.cuda.get_device_name
58
+ _cuda_memory_stats_as_nested_dict = torch.cuda.memory.memory_stats_as_nested_dict
59
+ _cuda_cudart = torch.cuda.cudart
60
+
61
+ # PyTorch 2.3
62
+ _cuda_maybe_exchange_device = getattr(torch.cuda, '_maybe_exchange_device', None)
63
+
64
+
65
+ cuda_aliases: dict[torch.Tensor, torch.Tensor | None] = WeakTensorKeyDictionary() # pyright: ignore [reportAssignmentType]
66
+
67
+ tensor_packs: list[ZeroGPUTensorPack] = []
68
+
69
+ class ZeroGPUTensor(torch.Tensor):
70
+ pass
71
+
72
+ def empty_fake(tensor: torch.Tensor):
73
+ fake = torch.empty_like(tensor, requires_grad=tensor.requires_grad)
74
+ if fake.__class__ != tensor.__class__:
75
+ fake = _tensor_make_subclass(tensor.__class__, fake, require_grad=tensor.requires_grad) # pyright: ignore [reportArgumentType]
76
+ return fake
77
+
78
+ # Torch 2.5: https://github.com/pytorch/pytorch/issues/144152
79
+ def no_int_device(*args, **kwargs):
80
+ if len(args) and isinstance(index := args[0], int):
81
+ args = (f'cuda:{index}', *args[1:])
82
+ if isinstance(index := kwargs.get('device'), int):
83
+ kwargs['device'] = f'cuda:{index}'
84
+ return args, kwargs
85
+
86
+
87
+ class ZeroGPUFunctionMode(TorchFunctionMode):
88
+
89
+ def __torch_function__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
90
+
91
+ kwargs = {} if kwargs is None else kwargs
92
+
93
+ if func == torch._C._nn._parse_to:
94
+ args, kwargs = no_int_device(*args, **kwargs)
95
+ return func(*args, **kwargs)
96
+
97
+ # Redispatch: tensor.cuda() -> tensor.to(device='cuda')
98
+ if func == torch.Tensor.cuda or func == torch.Tensor.cpu:
99
+ memory_format = kwargs.get('memory_format')
100
+ return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
101
+ 'device': 'cuda' if func == torch.Tensor.cuda else 'cpu',
102
+ **({'memory_format': memory_format} if memory_format is not None else {}),
103
+ })
104
+
105
+ # Redispatch: tensor.to('cuda') -> tensor.to(device='cuda')
106
+ if func == torch.Tensor.to and len(args) > 1:
107
+ parse_to_args, parse_to_kwargs = no_int_device(*args[1:], **kwargs)
108
+ device, dtype, _, memory_format = torch._C._nn._parse_to(*parse_to_args, **parse_to_kwargs) # pyright: ignore [reportCallIssue, reportArgumentType]
109
+ return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
110
+ 'device': device,
111
+ 'dtype': dtype,
112
+ 'memory_format': memory_format,
113
+ })
114
+
115
+ if func == torch.Tensor.data.__set__: # pyright: ignore [reportAttributeAccessIssue]
116
+ self, target = args
117
+ if target in cuda_aliases:
118
+ if (target_original := cuda_aliases[target]) is None:
119
+ raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), target))
120
+ original = empty_fake(self)
121
+ original.data = target_original
122
+ cuda_aliases[self] = original
123
+ elif self in cuda_aliases:
124
+ del cuda_aliases[self]
125
+ self.data = target
126
+ return
127
+
128
+ if func == torch.Tensor.device.__get__:
129
+ tensor, = args
130
+ if tensor in cuda_aliases:
131
+ return torch.device('cuda', index=0)
132
+
133
+ elif func == torch.Tensor.__repr__:
134
+ tensor, = args
135
+ if tensor in cuda_aliases:
136
+ if (original := cuda_aliases[tensor]) is None:
137
+ original = tensor.to('meta')
138
+ original_class = original.__class__
139
+ original.__class__ = ZeroGPUTensor
140
+ try:
141
+ return func(original, **kwargs)
142
+ finally:
143
+ original.__class__ = original_class
144
+
145
+ elif func == torch.Tensor.untyped_storage:
146
+ tensor, = args
147
+ if tensor in cuda_aliases:
148
+ if (original := cuda_aliases[tensor]) is None:
149
+ raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
150
+ res = func(original, **kwargs)
151
+ res._zerogpu = True
152
+ return res
153
+
154
+ cuda: bool | None = None
155
+
156
+ # Handle device kwarg
157
+ if (device := kwargs.get('device')) is not None:
158
+ device = torch.device(device)
159
+ if device.type == 'cuda':
160
+ kwargs['device'] = torch.device('cpu')
161
+ cuda = True
162
+ else:
163
+ cuda = False
164
+
165
+ # Swap fake inputs with original data
166
+ swapped = {}
167
+ inputs_are_cuda = set()
168
+ def swap(tensor: torch.Tensor):
169
+ nonlocal inputs_are_cuda
170
+ if tensor not in cuda_aliases:
171
+ inputs_are_cuda |= {False}
172
+ return tensor
173
+ if (original := cuda_aliases[tensor]) is None:
174
+ raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
175
+ swapped[original] = tensor
176
+ inputs_are_cuda |= {True}
177
+ return original
178
+ args_ = tree_map_only(torch.Tensor, swap, args)
179
+ kwargs_ = tree_map_only(torch.Tensor, swap, kwargs)
180
+ if inputs_are_cuda == {True}:
181
+ if cuda is not False:
182
+ cuda = True
183
+
184
+ res = func(*args_, **kwargs_)
185
+
186
+ # Re-generate swapped fakes in case of mutation
187
+ for original, fake in swapped.items():
188
+ fake.data = empty_fake(original)
189
+
190
+ # Special case for Tensor indexing where only 'self' matters
191
+ if func in {
192
+ torch.ops.aten.index.Tensor, # pyright: ignore [reportAttributeAccessIssue]
193
+ torch.Tensor.__getitem__, # PyTorch 2.4+
194
+ }:
195
+ self = args[0]
196
+ cuda = self in cuda_aliases
197
+ inputs_are_cuda = {cuda}
198
+
199
+ # Emulate device check
200
+ if isinstance(res, torch.Tensor) or func in OPS_INPUTS_CHECK_NO_RETURN:
201
+ self = None
202
+ if len(args_) >= 1 and isinstance(args_[0], torch.Tensor):
203
+ self = args_[0]
204
+ # Only raise if func does not return its first input (Tensor.copy_)
205
+ if res is not self or func in OPS_INPUT_CHECK_SELF_RETURN:
206
+ if inputs_are_cuda == {True, False}:
207
+ raise RuntimeError(
208
+ "Expected all tensors to be on the same device, "
209
+ "but found at least two devices, cuda:0 (ZeroGPU) and cpu!"
210
+ )
211
+
212
+ # Register output
213
+ def register(tensor: torch.Tensor):
214
+ if tensor in swapped and cuda is not False:
215
+ return swapped[tensor]
216
+ if cuda is not True:
217
+ return tensor
218
+ fake = empty_fake(tensor)
219
+ cuda_aliases[fake] = tensor
220
+ return fake
221
+
222
+ return tree_map_only(torch.Tensor, register, res)
223
+
224
+ # When enabling DispatchMode, some aten ops are dispatched to FunctionMode
225
+ # We are using it for aten.alias.default and aten.set_.source_Tensor
226
+ class DefaultDispatchMode(TorchDispatchMode):
227
+ def __torch_dispatch__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
228
+ return func(*args, **(kwargs or {}))
229
+
230
+
231
+ function_mode = ZeroGPUFunctionMode()
232
+ dispatch_mode = DefaultDispatchMode()
233
+
234
+
235
+ def _untyped_storage_new_register(*args, **kwargs):
236
+ cuda = False
237
+ if (device := kwargs.get('device')) is not None and device.type == 'cuda':
238
+ cuda = True
239
+ del kwargs['device']
240
+ storage = torch._C.StorageBase.__new__(*args, **kwargs)
241
+ if cuda:
242
+ storage._zerogpu = True
243
+ return storage
244
+
245
+ @property
246
+ def _untyped_storage_device(self):
247
+ if hasattr(self, '_zerogpu'):
248
+ return torch.device('cuda', index=0)
249
+ return torch._C.StorageBase.device.__get__(self) # pyright: ignore [reportAttributeAccessIssue]
250
+
251
+ # Force dispatch
252
+ def _tensor_make_subclass_function_mode(*args, **kwargs):
253
+ with torch._C.DisableTorchFunction():
254
+ return function_mode.__torch_function__(_tensor_make_subclass, (), args=args, kwargs=kwargs)
255
+ def _asarray_function_mode(*args, **kwargs):
256
+ with torch._C.DisableTorchFunction():
257
+ return function_mode.__torch_function__(_asarray, (), args=args, kwargs=kwargs)
258
+
259
+ class _DeviceStringOnlyMeta(type):
260
+ def __instancecheck__(cls, instance):
261
+ return isinstance(instance, _device)
262
+
263
+ class _DeviceStringOnly(metaclass=_DeviceStringOnlyMeta):
264
+ def __new__(cls, *args, **kwargs):
265
+ args, kwargs = no_int_device(*args, **kwargs)
266
+ return _device(*args, **kwargs)
267
+
268
+ def _cuda_init_raise():
269
+ raise RuntimeError(
270
+ "CUDA must not be initialized in the main process "
271
+ "on Spaces with Stateless GPU environment.\n"
272
+ "You can look at this Stacktrace to find out "
273
+ "which part of your code triggered a CUDA init"
274
+ )
275
+
276
+ def _cuda_dummy_exchange_device(device):
277
+ assert device in {-1, 0}
278
+ return device
279
+
280
+ def patch():
281
+ function_mode.__enter__()
282
+ dispatch_mode.__enter__()
283
+ # TODO: only patch the methods below on the current Thread to be consistent with TorchModes
284
+ # (or hijack threading.Thread.__init__ to force Modes on all threads)
285
+ torch.Tensor._make_subclass = _tensor_make_subclass_function_mode # pyright: ignore [reportAttributeAccessIssue]
286
+ torch.UntypedStorage.__new__ = _untyped_storage_new_register
287
+ torch.UntypedStorage.device = _untyped_storage_device # pyright: ignore [reportAttributeAccessIssue]
288
+ torch.asarray = _asarray_function_mode
289
+ torch.device = _DeviceStringOnly
290
+ torch._C._cuda_init = _cuda_init_raise
291
+ torch.cuda._exchange_device = _cuda_dummy_exchange_device
292
+ torch.cuda.is_available = lambda: True
293
+ torch.cuda.device_count = lambda: 1
294
+ torch.cuda.current_device = lambda: 0
295
+ torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
296
+ torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
297
+ torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
298
+ torch.cuda.memory.memory_stats_as_nested_dict = lambda *args, **kwargs: CUDA_MEMORY_STATS_AS_NESTED_DICT
299
+ torch.cuda.cudart = lambda: cudart
300
+ # PyTorch 2.3
301
+ if _cuda_maybe_exchange_device is not None: # pragma: no cover
302
+ setattr(torch.cuda, '_maybe_exchange_device', _cuda_dummy_exchange_device)
303
+ bitsandbytes().patch()
304
+
305
+ def unpatch():
306
+ try:
307
+ dispatch_mode.__exit__(None, None, None)
308
+ function_mode.__exit__(None, None, None)
309
+ except RuntimeError:
310
+ pass # patch() and unpatch() called from different threads
311
+ torch.Tensor._make_subclass = _tensor_make_subclass
312
+ torch.UntypedStorage.__new__ = torch._C.StorageBase.__new__
313
+ torch.UntypedStorage.device = torch._C.StorageBase.device # pyright: ignore [reportAttributeAccessIssue]
314
+ torch.asarray = _asarray
315
+ torch.device = _device
316
+ torch._C._cuda_init = _cuda_init
317
+ torch.cuda._exchange_device = _cuda_exchange_device
318
+ torch.cuda.is_available = _cuda_available
319
+ torch.cuda.device_count = _cuda_device_count
320
+ torch.cuda.current_device = _cuda_current_device
321
+ torch.cuda.get_device_capability = _cuda_get_device_capability
322
+ torch.cuda.get_device_properties = _cuda_get_device_properties
323
+ torch.cuda.get_device_name = _cuda_get_device_name
324
+ torch.cuda.memory.memory_stats_as_nested_dict = _cuda_memory_stats_as_nested_dict
325
+ torch.cuda.cudart = _cuda_cudart
326
+ # PyTorch 2.3
327
+ if _cuda_maybe_exchange_device is not None: # pragma: no cover
328
+ setattr(torch.cuda, '_maybe_exchange_device', _cuda_exchange_device)
329
+ bitsandbytes().unpatch()
330
+
331
+
332
+ def _total_unpacked_size():
333
+ tensors = [tensor for tensor in cuda_aliases.values() if tensor is not None]
334
+ deduped = {AliasId.from_tensor(tensor): tensor for tensor in tensors}
335
+ return sum([tensor.numel() * tensor.element_size() for tensor in deduped.values()])
336
+
337
+
338
+ def _pack(offload_dir: str):
339
+ # Pack to disk
340
+ originals: set[torch.Tensor] = set()
341
+ originals_dedup: dict[AliasId, torch.Tensor] = {}
342
+ fakes: dict[torch.Tensor, list[torch.Tensor]] = defaultdict(list)
343
+ for fake, original in cuda_aliases.items():
344
+ # TODO filter-out sparse Tensors
345
+ if original is not None:
346
+ original_id = AliasId.from_tensor(original)
347
+ if original_id not in originals_dedup:
348
+ originals_dedup[original_id] = original
349
+ originals |= {original}
350
+ fakes[originals_dedup[original_id]] += [fake]
351
+ progress = tqdm(
352
+ total=_total_unpacked_size(),
353
+ unit='B',
354
+ unit_scale=True,
355
+ desc="ZeroGPU tensors packing",
356
+ ) if tqdm is not None else nullcontext()
357
+ with progress as progress:
358
+ update = progress.update if progress is not None else lambda _: None
359
+ pack = pack_tensors(originals, fakes, offload_dir, callback=update)
360
+ tensor_packs.append(pack)
361
+ # Free memory
362
+ for fake_list in fakes.values():
363
+ for fake in fake_list:
364
+ cuda_aliases[fake] = None
365
+
366
+ def pack():
367
+ _pack(Config.zerogpu_offload_dir)
368
+ gc.collect()
369
+ malloc_trim()
370
+
371
+ def init(nvidia_uuid: str):
372
+ os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
373
+ torch.Tensor([0]).cuda()
374
+
375
+ def size():
376
+ return _total_unpacked_size() + sum([pack.total_size for pack in tensor_packs])
377
+
378
+ def _move(callback: Callable[[int], None] | None = None):
379
+ callback = callback if callback is not None else lambda _: None
380
+ # CPU -> CUDA
381
+ moved: dict[AliasId, torch.Tensor] = {}
382
+ for fake, original in cuda_aliases.items():
383
+ if original is not None:
384
+ original_id = AliasId.from_tensor(original)
385
+ if original_id not in moved:
386
+ moved[original_id] = original.cuda()
387
+ callback(fake.numel() * fake.element_size())
388
+ for fake, original in cuda_aliases.items():
389
+ if original is not None:
390
+ fake.data = moved[AliasId.from_tensor(original)]
391
+ # Disk -> CUDA
392
+ for tensor_pack in tensor_packs:
393
+ pack_to_cuda(tensor_pack, callback=callback)
394
+ bitsandbytes().move()
395
+
396
+ def move(callback: Callable[[int], None] | None = None):
397
+ callback = callback if callback is not None else lambda _: None
398
+ with ThreadPoolExecutor(1) as e:
399
+ e.submit(copy_context().run, _move, callback=callback).result()
400
+ torch.cuda.synchronize()
401
+
402
+ def is_in_bad_fork():
403
+ with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
404
+ f = e.submit(torch.cuda._is_in_bad_fork)
405
+ return f.result()
406
+
407
+ def bitsandbytes():
408
+ # Lazy import
409
+ from . import bitsandbytes
410
+ return bitsandbytes
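A minimal, self-contained sketch of the interception idea implemented above (illustrative only, not the package code): a TorchFunctionMode rewrites device='cuda' arguments so allocations silently stay on CPU, which is why no real CUDA initialization happens in the main process.

import torch
from torch.overrides import TorchFunctionMode

class FakeCudaMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = dict(kwargs or {})
        device = kwargs.get('device')
        if device is not None and torch.device(device).type == 'cuda':
            kwargs['device'] = torch.device('cpu')  # keep the allocation on CPU
        return func(*args, **kwargs)

with FakeCudaMode():
    t = torch.zeros(4, device='cuda')  # no GPU required

print(t.device)  # cpu -- the real ZeroGPUFunctionMode additionally tracks t as a fake CUDA alias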
spaces/zero/torch/patching_legacy.py ADDED
@@ -0,0 +1,266 @@
1
+ """
2
+ """
3
+ # pyright: reportPrivateImportUsage=false
4
+
5
+ from __future__ import annotations
6
+
7
+ import multiprocessing
8
+ import os
9
+ from concurrent.futures import ProcessPoolExecutor
10
+ from contextlib import suppress
11
+ from functools import partial
12
+ from types import SimpleNamespace
13
+ from typing import Any
14
+ from typing import Callable
15
+ from typing import Optional
16
+ from typing import Tuple
17
+
18
+ import torch
19
+ from torch.utils.weak import WeakTensorKeyDictionary
20
+
21
+ from ...config import Config
22
+ from . import bitsandbytes
23
+
24
+
25
+ # Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
26
+ CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
27
+ CUDA_TOTAL_MEMORY = 42144366592
28
+ CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
29
+ CUDA_DEVICE_CAPABILITY = (8, 0)
30
+ CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)
31
+
32
+ GENERIC_METHOD_NAMES = [
33
+ 'arange',
34
+ 'as_tensor',
35
+ 'asarray',
36
+ 'bartlett_window',
37
+ 'blackman_window',
38
+ 'empty',
39
+ 'empty_like',
40
+ 'empty_strided',
41
+ 'eye',
42
+ 'full',
43
+ 'full_like',
44
+ 'hamming_window',
45
+ 'hann_window',
46
+ 'kaiser_window',
47
+ 'linspace',
48
+ 'logspace',
49
+ 'ones',
50
+ 'ones_like',
51
+ 'rand',
52
+ 'rand_like',
53
+ 'randint',
54
+ 'randint_like',
55
+ 'randn',
56
+ 'randn_like',
57
+ 'randperm',
58
+ 'range',
59
+ 'sparse_bsc_tensor',
60
+ 'sparse_bsr_tensor',
61
+ 'sparse_compressed_tensor',
62
+ 'sparse_coo_tensor',
63
+ 'sparse_csc_tensor',
64
+ 'sparse_csr_tensor',
65
+ 'tensor',
66
+ 'tril_indices',
67
+ 'triu_indices',
68
+ 'zeros',
69
+ 'zeros_like',
70
+ ]
71
+
72
+
73
+ TO_CUDA = (torch.device('cuda'), None, False, None)
74
+
75
+ _tensor__deepcopy__ = torch.Tensor.__deepcopy__
76
+ _tensor_to = torch.Tensor.to
77
+ _tensor_cuda = torch.Tensor.cuda
78
+ _tensor_cpu = torch.Tensor.cpu
79
+ _torch_generics = {name: getattr(torch, name) for name in GENERIC_METHOD_NAMES}
80
+ _cuda_init = torch._C._cuda_init
81
+ _cuda_available = torch.cuda.is_available
82
+ _cuda_device_count = torch.cuda.device_count
83
+ _cuda_current_device = torch.cuda.current_device
84
+ _cuda_mem_get_info = torch.cuda.mem_get_info
85
+ _cuda_get_device_capability = torch.cuda.get_device_capability
86
+ _cuda_get_device_properties = torch.cuda.get_device_properties
87
+ _cuda_get_device_name = torch.cuda.get_device_name
88
+
89
+ TensorToArgs = Tuple[Optional[torch.device], Optional[torch.dtype], bool, Optional[torch.memory_format]]
90
+
91
+ to_ops: dict[torch.Tensor, TensorToArgs] = WeakTensorKeyDictionary() # type: ignore
92
+
93
+ def _tensor_new_register(*args, **kwargs):
94
+ new_tensor: torch.Tensor = torch._C._TensorBase.__new__(*args, **kwargs)
95
+ if (base_tensor := new_tensor._base) is not None:
96
+ if base_tensor in to_ops:
97
+ to_ops[new_tensor] = to_ops[base_tensor]
98
+ return new_tensor
99
+
100
+ def _tensor_deepcopy_register(self: torch.Tensor, memo):
101
+ new_tensor = _tensor__deepcopy__(self, memo)
102
+ if isinstance(new_tensor, torch.Tensor):
103
+ if self in to_ops:
104
+ to_ops[new_tensor] = to_ops[self]
105
+ return new_tensor
106
+
107
+ @property
108
+ def _tensor_device_property(self: torch.Tensor):
109
+ if self in to_ops:
110
+ return torch.device(type='cuda', index=0)
111
+ del torch.Tensor.device
112
+ try:
113
+ return self.device
114
+ finally:
115
+ torch.Tensor.device = _tensor_device_property # type: ignore
116
+
117
+ @property
118
+ def _tensor_dtype_property(self: torch.Tensor):
119
+ if self in to_ops:
120
+ if (to_dtype := to_ops[self][1]) is not None:
121
+ return to_dtype
122
+ del torch.Tensor.dtype
123
+ try:
124
+ return self.dtype
125
+ finally:
126
+ torch.Tensor.dtype = _tensor_dtype_property # type: ignore
127
+
128
+ def _to_op_register(self: torch.Tensor, *args, **kwargs):
129
+ parsed = torch._C._nn._parse_to(*args, **kwargs)
130
+ device, dtype, *_ = parsed
131
+ try:
132
+ to_args = to_ops.pop(self)
133
+ except KeyError:
134
+ to_args = None
135
+ if device is None: # pyright: ignore [reportUnnecessaryComparison]
136
+ if to_args is not None:
137
+ to_ops[self] = (to_args[0], dtype, *to_args[2:])
138
+ return self
139
+ return _tensor_to(self, *args, **kwargs)
140
+ if device.type != 'cuda':
141
+ if to_args is not None:
142
+ if (to_dtype := to_args[1]) is not None:
143
+ kwargs = {'dtype': to_dtype, **kwargs}
144
+ return _tensor_to(self, *args, **kwargs)
145
+ to_ops[self] = parsed
146
+ return self
147
+
148
+ def _cuda_op_arg_check(device: torch.device | int | str | None) -> bool:
149
+ if device is None:
150
+ return True
151
+ if isinstance(device, int):
152
+ return True
153
+ if isinstance(device, str):
154
+ device = torch.device(device)
155
+ return device.type == 'cuda'
156
+
157
+ def _cuda_op_register(self: torch.Tensor, device: torch.device | int | str | None = None, **kwargs):
158
+ if not _cuda_op_arg_check(device):
159
+ # Let PyTorch handle the fail
160
+ return _tensor_cuda(self, device, **kwargs)
161
+ to_ops[self] = TO_CUDA
162
+ return self
163
+
164
+ def _cpu_op_remove(self: torch.Tensor, **kwargs):
165
+ try:
166
+ to_args = to_ops.pop(self)
167
+ except KeyError:
168
+ to_args = None
169
+ if to_args is not None:
170
+ if (to_dtype := to_args[1]) is not None:
171
+ return _tensor_to(self, 'cpu', **{'dtype': to_dtype, **kwargs})
172
+ return _tensor_cpu(self, **kwargs)
173
+
174
+ def _cuda_init_raise():
175
+ raise RuntimeError(
176
+ "CUDA must not be initialized in the main process "
177
+ "on Spaces with Stateless GPU environment.\n"
178
+ "You can look at this Stacktrace to find out "
179
+ "which part of your code triggered a CUDA init"
180
+ )
181
+
182
+ def _generic_method_register(name: str, *args: Any, **kwargs: Any):
183
+ try:
184
+ device = torch.device(kwargs.get('device', "cpu"))
185
+ except Exception:
186
+ return _torch_generics[name](*args, **kwargs)
187
+ if device.type != 'cuda':
188
+ return _torch_generics[name](*args, **kwargs)
189
+ tensor = _torch_generics[name](*args, **{**kwargs, 'device': "cpu"})
190
+ to_ops[tensor] = TO_CUDA
191
+ return tensor
192
+
193
+ def patch():
194
+ torch.Tensor.__deepcopy__ = _tensor_deepcopy_register
195
+ torch.Tensor.__new__ = _tensor_new_register # pyright: ignore [reportAttributeAccessIssue]
196
+ torch.Tensor.to = _to_op_register # type: ignore
197
+ torch.Tensor.cuda = _cuda_op_register # type: ignore
198
+ torch.Tensor.cpu = _cpu_op_remove # type: ignore
199
+ if Config.zero_patch_torch_device:
200
+ torch.Tensor.device = _tensor_device_property # type: ignore
201
+ torch.Tensor.dtype = _tensor_dtype_property # pyright: ignore [reportAttributeAccessIssue]
202
+ for name in GENERIC_METHOD_NAMES:
203
+ setattr(torch, name, partial(_generic_method_register, name))
204
+ torch._C._cuda_init = _cuda_init_raise
205
+ torch.cuda.is_available = lambda: True
206
+ torch.cuda.device_count = lambda: 1
207
+ torch.cuda.current_device = lambda: 0
208
+ torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
209
+ torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
210
+ torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
211
+ torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
212
+ bitsandbytes.patch()
213
+
214
+ def unpatch():
215
+ torch.Tensor.__deepcopy__ = _tensor__deepcopy__
216
+ with suppress(AttributeError):
217
+ del torch.Tensor.__new__
218
+ torch.Tensor.to = _tensor_to
219
+ torch.Tensor.cuda = _tensor_cuda
220
+ torch.Tensor.cpu = _tensor_cpu
221
+ with suppress(AttributeError):
222
+ del torch.Tensor.device
223
+ with suppress(AttributeError):
224
+ del torch.Tensor.dtype
225
+ for name in GENERIC_METHOD_NAMES:
226
+ setattr(torch, name, _torch_generics[name])
227
+ torch._C._cuda_init = _cuda_init
228
+ torch.cuda.is_available = _cuda_available
229
+ torch.cuda.device_count = _cuda_device_count
230
+ torch.cuda.current_device = _cuda_current_device
231
+ torch.cuda.mem_get_info = _cuda_mem_get_info
232
+ torch.cuda.get_device_capability = _cuda_get_device_capability
233
+ torch.cuda.get_device_properties = _cuda_get_device_properties
234
+ torch.cuda.get_device_name = _cuda_get_device_name
235
+ bitsandbytes.unpatch()
236
+
237
+ def pack():
238
+ pass
239
+
240
+ def init(nvidia_uuid: str):
241
+ os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
242
+ torch.Tensor([0]).cuda() # CUDA init
243
+
244
+ def size():
245
+ return 0
246
+
247
+ def move(callback: Callable[[int], None] | None = None):
248
+ for op in to_ops.items():
249
+ tensor, parsed_args = op
250
+ _, dtype, _, memory_format = parsed_args
251
+ tensor.data = _tensor_to(tensor,
252
+ device='cuda',
253
+ dtype=dtype,
254
+ memory_format=memory_format,
255
+ ) # type: ignore
256
+ bitsandbytes.move()
257
+ torch.cuda.synchronize()
258
+
259
+ def is_in_bad_fork():
260
+ with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
261
+ f = e.submit(torch.cuda._is_in_bad_fork)
262
+ return f.result()
263
+
264
+ def disable_cuda_intercept():
265
+ torch.Tensor.to = _tensor_to
266
+ torch.Tensor.cuda = _tensor_cuda
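Self-contained sketch of the legacy strategy above (illustrative only): .cuda() and .to('cuda') merely record the requested move in a weak per-tensor table, and the real transfer is replayed later by move() once a GPU has been attributed.

import torch
from torch.utils.weak import WeakTensorKeyDictionary

pending = WeakTensorKeyDictionary()  # tensor -> target device, mirroring to_ops

def lazy_cuda(tensor: torch.Tensor) -> torch.Tensor:
    pending[tensor] = torch.device('cuda')  # remember the intent, keep data on CPU
    return tensor

weights = lazy_cuda(torch.ones(2, 2))
print(weights.device)  # still cpu

def replay_moves():  # roughly what move() does once CUDA is really available
    for tensor, device in pending.items():
        tensor.data = tensor.to(device)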
spaces/zero/torch/static.py ADDED
@@ -0,0 +1,136 @@
1
+ """
2
+ """
3
+
4
+ from types import SimpleNamespace as _SimpleNamespace
5
+
6
+ import torch as _torch
7
+
8
+ from ...config import Config
9
+
10
+
11
+ def compute_base_free_memory(total_memory: int):
12
+ pytorch_base_memory = 309002240 # TODO: fine-grain per: torch-version x GPU(-MIG) model
13
+ return total_memory - pytorch_base_memory - Config.zerogpu_cuda_reserved_memory
14
+
15
+ CUDA_DEVICE_NAME = Config.zerogpu_cuda_device_name
16
+ CUDA_TOTAL_MEMORY = Config.zerogpu_cuda_total_memory
17
+ CUDA_MEM_GET_INFO = (compute_base_free_memory(CUDA_TOTAL_MEMORY), CUDA_TOTAL_MEMORY)
18
+ CUDA_DEVICE_CAPABILITY = (Config.zerogpu_cuda_capability_major, Config.zerogpu_cuda_capability_minor)
19
+ CUDA_DEVICE_PROPERTIES = _SimpleNamespace(
20
+ name=CUDA_DEVICE_NAME,
21
+ major=CUDA_DEVICE_CAPABILITY[0],
22
+ minor=CUDA_DEVICE_CAPABILITY[1],
23
+ total_memory=CUDA_TOTAL_MEMORY,
24
+ multi_processor_count=Config.zerogpu_cuda_multi_processor_count,
25
+ # TODO: L2_cache_size
26
+ )
27
+
28
+ if _torch.version.cuda.startswith("12."): # pyright: ignore [reportAttributeAccessIssue]
29
+ CUDA_MEMORY_STATS_AS_NESTED_DICT = {
30
+ "num_alloc_retries": 0,
31
+ "num_ooms": 0,
32
+ "max_split_size": -1,
33
+ "num_sync_all_streams": 0,
34
+ "num_device_alloc": 0,
35
+ "num_device_free": 0,
36
+ "allocation": {
37
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
38
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
39
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
40
+ },
41
+ "segment": {
42
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
43
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
44
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
45
+ },
46
+ "active": {
47
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
48
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
49
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
50
+ },
51
+ "inactive_split": {
52
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
53
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
54
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
55
+ },
56
+ "allocated_bytes": {
57
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
58
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
59
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
60
+ },
61
+ "reserved_bytes": {
62
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
63
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
64
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
65
+ },
66
+ "active_bytes": {
67
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
68
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
69
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
70
+ },
71
+ "inactive_split_bytes": {
72
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
73
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
74
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
75
+ },
76
+ "requested_bytes": {
77
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
78
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
79
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
80
+ },
81
+ "oversize_allocations": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
82
+ "oversize_segments": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
83
+ }
84
+ else: # pragma: no cover (CUDA 11)
85
+ CUDA_MEMORY_STATS_AS_NESTED_DICT = {
86
+ "num_alloc_retries": 0,
87
+ "num_ooms": 0,
88
+ "max_split_size": -1,
89
+ "allocation": {
90
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
91
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
92
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
93
+ },
94
+ "segment": {
95
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
96
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
97
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
98
+ },
99
+ "active": {
100
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
101
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
102
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
103
+ },
104
+ "inactive_split": {
105
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
106
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
107
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
108
+ },
109
+ "allocated_bytes": {
110
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
111
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
112
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
113
+ },
114
+ "reserved_bytes": {
115
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
116
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
117
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
118
+ },
119
+ "active_bytes": {
120
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
121
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
122
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
123
+ },
124
+ "inactive_split_bytes": {
125
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
126
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
127
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
128
+ },
129
+ "requested_bytes": {
130
+ "all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
131
+ "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
132
+ "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
133
+ },
134
+ "oversize_allocations": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
135
+ "oversize_segments": {"current": 0, "peak": 0, "allocated": 0, "freed": 0},
136
+ }
spaces/zero/torch/types.py ADDED
@@ -0,0 +1,23 @@
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ from typing import NamedTuple
6
+
7
+ import torch
8
+
9
+
10
+ class AliasId(NamedTuple):
11
+ data_ptr: int
12
+ dtype: torch.dtype
13
+ shape: tuple[int, ...]
14
+ stride: tuple[int, ...]
15
+
16
+ @classmethod
17
+ def from_tensor(cls, tensor: torch.Tensor):
18
+ return cls(
19
+ tensor.data_ptr(),
20
+ tensor.dtype,
21
+ tensor.shape,
22
+ tensor.stride(),
23
+ )
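Illustrative sketch of how AliasId-style keys deduplicate tensors that share the same underlying data, which is what _pack() and _move() in patching.py rely on:

import torch

def alias_key(t: torch.Tensor):
    # same fields as AliasId.from_tensor
    return (t.data_ptr(), t.dtype, tuple(t.shape), tuple(t.stride()))

a = torch.zeros(8)
b = a.view(8)   # same storage, same layout
c = a[:4]       # same storage, different shape

print(alias_key(a) == alias_key(b))  # True: packed and moved only once
print(alias_key(a) == alias_key(c))  # False: treated as a distinct alias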
spaces/zero/tqdm.py ADDED
@@ -0,0 +1,24 @@
1
+ """
2
+ """
3
+
4
+ from multiprocessing.synchronize import RLock as MultiprocessingRLock
5
+
6
+
7
+ try:
8
+ from tqdm import tqdm as _tqdm
9
+ except ImportError: # pragma: no cover
10
+ _tqdm = None
11
+
12
+
13
+ def remove_tqdm_multiprocessing_lock():
14
+ if _tqdm is None: # pragma: no cover
15
+ return
16
+ tqdm_lock = _tqdm.get_lock()
17
+ assert tqdm_lock.__class__.__name__ == 'TqdmDefaultWriteLock'
18
+ tqdm_lock.locks = [
19
+ lock for lock in tqdm_lock.locks
20
+ if not isinstance(lock, MultiprocessingRLock)
21
+ ]
22
+
23
+
24
+ tqdm = _tqdm
spaces/zero/types.py ADDED
@@ -0,0 +1,50 @@
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+
6
+ from dataclasses import dataclass
7
+ from datetime import timedelta
8
+ from typing import Any
9
+ from typing import Dict
10
+ from typing import Tuple
11
+ from typing import TypedDict
12
+ from typing_extensions import Callable
13
+ from typing_extensions import Generic
14
+ from typing_extensions import ParamSpec
15
+ from typing_extensions import TypeAlias
16
+ from typing_extensions import TypeVar
17
+
18
+
19
+ Params = Tuple[Tuple[object, ...], Dict[str, Any]]
20
+ Res = TypeVar('Res')
21
+ Param = ParamSpec('Param')
22
+
23
+ class EmptyKwargs(TypedDict):
24
+ pass
25
+
26
+ @dataclass
27
+ class OkResult(Generic[Res]):
28
+ value: Res
29
+ @dataclass
30
+ class ExceptionResult:
31
+ traceback: str
32
+ error_cls: str
33
+ @dataclass
34
+ class AbortedResult:
35
+ pass
36
+ @dataclass
37
+ class EndResult:
38
+ pass
39
+ @dataclass
40
+ class GradioQueueEvent:
41
+ method_name: str
42
+ args: tuple[Any, ...]
43
+ kwargs: dict[str, Any]
44
+
45
+ RegularResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | GradioQueueEvent"
46
+ GeneratorResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | GradioQueueEvent"
47
+ YieldQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | AbortedResult"
48
+
49
+ Duration: TypeAlias = "int | timedelta"
50
+ DynamicDuration: TypeAlias = "Duration | Callable[Param, Duration] | None"
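Hedged usage sketch: DynamicDuration lets the scheduled GPU time be a fixed number of seconds, a timedelta, or a callable computed from the task's own arguments (resolved by static_duration() in wrappers.py below). The @spaces.GPU(duration=...) form is assumed here for illustration.

from datetime import timedelta
import spaces

def needed_duration(prompt: str, steps: int = 20) -> timedelta:
    return timedelta(seconds=10 + steps)  # scale the GPU slot with the workload

@spaces.GPU(duration=needed_duration)
def generate(prompt: str, steps: int = 20):
    ...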
spaces/zero/wrappers.py ADDED
@@ -0,0 +1,423 @@
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ import multiprocessing
6
+ import os
7
+ import signal
8
+ import traceback
9
+ import warnings
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ from contextlib import nullcontext
12
+ from contextvars import copy_context
13
+ from datetime import timedelta
14
+ from functools import partial
15
+ from functools import wraps
16
+ from multiprocessing.context import ForkProcess
17
+ from pickle import PicklingError
18
+ from queue import Empty
19
+ from queue import Queue as ThreadQueue
20
+ from threading import Thread
21
+ from typing import TYPE_CHECKING
22
+ from typing import Callable
23
+ from typing import Generator
24
+ from typing import Generic
25
+ from typing_extensions import assert_never
26
+
27
+ import psutil
28
+
29
+ from ..config import Config
30
+ from ..utils import debug
31
+ from ..utils import drop_params
32
+ from ..utils import gradio_request_var
33
+ from ..utils import SimpleQueue as Queue
34
+ from . import client
35
+ from . import torch
36
+ from .api import AllowToken
37
+ from .api import NvidiaIndex
38
+ from .api import NvidiaUUID
39
+ from .gradio import GradioPartialContext
40
+ from .gradio import error
41
+ from .gradio import get_server_port
42
+ from .gradio import patch_gradio_queue
43
+ from .gradio import try_process_queue_event
44
+ from .tqdm import remove_tqdm_multiprocessing_lock
45
+ from .tqdm import tqdm
46
+ from .types import * # TODO: Please don't do that
47
+
48
+
49
+ GENERATOR_GLOBAL_TIMEOUT = 20 * 60
50
+
51
+ SPAWN_PROGRESS_CLEANUP = 0.1
52
+ SPAWN_PROGRESS_INIT = 0.1
53
+
54
+
55
+ Process = multiprocessing.get_context('fork').Process
56
+ forked = False
57
+
58
+
59
+ class Worker(Generic[Res]):
60
+ process: ForkProcess
61
+ arg_queue: Queue[tuple[Params, GradioPartialContext]]
62
+ res_queue: Queue[Res | None]
63
+ _sentinel: Thread
64
+
65
+ def __init__(
66
+ self,
67
+ target: Callable[[
68
+ Queue[tuple[Params, GradioPartialContext]],
69
+ Queue[Res | None],
70
+ AllowToken,
71
+ NvidiaUUID,
72
+ list[int],
73
+ ], None],
74
+ allow_token: str,
75
+ nvidia_uuid: str,
76
+ ):
77
+ self._sentinel = Thread(target=self._close_on_exit, daemon=True)
78
+ self.arg_queue = Queue()
79
+ self.res_queue = Queue()
80
+ debug(f"{self.arg_queue._writer.fileno()=}") # pyright: ignore [reportAttributeAccessIssue]
81
+ debug(f"{self.res_queue._writer.fileno()=}") # pyright: ignore [reportAttributeAccessIssue]
82
+ if (server_port := get_server_port()) is not None:
83
+ fds = [c.fd for c in psutil.Process().connections() if c.laddr.port == server_port]
84
+ debug(f"{fds=}")
85
+ else:
86
+ warnings.warn("Using a ZeroGPU function outside of Gradio caching or request might block the app")
87
+ fds = []
88
+ args = self.arg_queue, self.res_queue, allow_token, nvidia_uuid, fds
89
+ if TYPE_CHECKING:
90
+ target(*args)
91
+ self.process = Process(
92
+ target=target,
93
+ args=args,
94
+ daemon=True,
95
+ )
96
+ self.process.start()
97
+ self._sentinel.start()
98
+
99
+ def _close_on_exit(self):
100
+ self.process.join()
101
+ self.arg_queue.close()
102
+ self.res_queue.wlock_release()
103
+ self.res_queue.put(None)
104
+
105
+
106
+ def worker_init(
107
+ res_queue: Queue[RegularResQueueResult | None] | Queue[GeneratorResQueueResult | None],
108
+ allow_token: str,
109
+ nvidia_uuid: str,
110
+ fds: list[int],
111
+ ) -> None | ExceptionResult:
112
+ # Immediately close file descriptors
113
+ for fd in fds:
114
+ try:
115
+ os.close(fd)
116
+ except Exception as e: # pragma: no cover
117
+ if isinstance(e, OSError) and e.errno == 9:
118
+ continue
119
+ return exception_result(e)
120
+ try:
121
+ remove_tqdm_multiprocessing_lock()
122
+ except Exception: # pragma: no cover
123
+ print("Error while trying to remove tqdm mp_lock:")
124
+ traceback.print_exc()
125
+ progress = nullcontext()
126
+ if tqdm is not None and Config.zero_gpu_v2:
127
+ progress = tqdm(total=100, desc="ZeroGPU init", file=open(os.devnull, 'w'))
128
+ try: # Unrecoverable init part
129
+ patch_gradio_queue(res_queue)
130
+ with progress as progress:
131
+ current_progress = 0 # Gradio does not support float progress updates
132
+ def update(n: float):
133
+ nonlocal current_progress
134
+ current_progress += n
135
+ if progress is not None:
136
+ progress.update(round(current_progress * 100) - progress.n)
137
+ client.allow(allow_token)
138
+ update(SPAWN_PROGRESS_CLEANUP)
139
+ torch.unpatch()
140
+ torch.init(nvidia_uuid)
141
+ update(SPAWN_PROGRESS_INIT)
142
+ callback = None
143
+ if (transfer_size := torch.size()) > 0:
144
+ remaining = 1 - (SPAWN_PROGRESS_CLEANUP + SPAWN_PROGRESS_INIT)
145
+ callback = lambda n: update(n * remaining / transfer_size)
146
+ torch.move(callback=callback)
147
+ except Exception as e: # pragma: no cover
148
+ return exception_result(e)
149
+
150
+
151
+ def process_duration(duration: Duration | None):
152
+ if duration is None or isinstance(duration, timedelta):
153
+ return duration
154
+ return timedelta(seconds=duration)
155
+
156
+
157
+ def static_duration(duration: DynamicDuration[Param], *args: Param.args, **kwargs: Param.kwargs):
158
+ if not callable(duration):
159
+ return duration
160
+ return duration(*args, **kwargs)
161
+
162
+
163
+ def regular_function_wrapper(
164
+ task: Callable[Param, Res],
165
+ duration: DynamicDuration[Param],
166
+ ) -> Callable[Param, Res]:
167
+
168
+ import gradio as gr
169
+
170
+ request_var = gradio_request_var()
171
+ workers: dict[NvidiaIndex, Worker[RegularResQueueResult[Res] | None]] = {}
172
+ task_id = id(task)
173
+
174
+ @wraps(task)
175
+ def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Res:
176
+
177
+ if forked:
178
+ return task(*args, **kwargs)
179
+
180
+ request = request_var.get()
181
+ duration_ = static_duration(duration, *args, **kwargs)
182
+ duration_ = process_duration(duration_)
183
+ schedule_response = client.schedule(task_id=task_id, request=request, duration=duration_)
184
+ allow_token = schedule_response.allowToken
185
+ nvidia_index = schedule_response.nvidiaIndex
186
+ nvidia_uuid = schedule_response.nvidiaUUID
187
+ release = partial(client.release, allow_token)
188
+
189
+ try:
190
+ worker = workers.pop(nvidia_index)
191
+ except KeyError:
192
+ worker = None
193
+
194
+ if worker is not None and worker.process.is_alive() and schedule_response.idle:
195
+ assert worker.arg_queue.empty()
196
+ assert worker.res_queue.empty()
197
+ else:
198
+ worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
199
+
200
+ try:
201
+ worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
202
+ except PicklingError: # TODO: detailed serialization diagnostic
203
+ release(fail=True)
204
+ raise
205
+
206
+ while True:
207
+ res = worker.res_queue.get()
208
+ if res is None:
209
+ release(fail=True, allow_404=True)
210
+ raise error("ZeroGPU worker error", "GPU task aborted")
211
+ if isinstance(res, ExceptionResult):
212
+ release(fail=True)
213
+ print(res.traceback)
214
+ raise error("ZeroGPU worker error", res.error_cls)
215
+ if isinstance(res, OkResult):
216
+ release()
217
+ workers[nvidia_index] = worker
218
+ return res.value
219
+ if isinstance(res, GradioQueueEvent):
220
+ try_process_queue_event(res.method_name, *res.args, **res.kwargs)
221
+ continue
222
+ assert_never(res)
223
+
224
+
225
+ def thread_wrapper(
226
+ arg_queue: Queue[tuple[Params, GradioPartialContext]],
227
+ res_queue: Queue[RegularResQueueResult[Res] | None],
228
+ allow_token: str,
229
+ nvidia_uuid: str,
230
+ fds: list[int],
231
+ ):
232
+ global forked
233
+ forked = True
234
+ signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
235
+ initialized = False
236
+ while True:
237
+ try:
238
+ (args, kwargs), gradio_context = arg_queue.get()
239
+ except OSError:
240
+ break
241
+ if not initialized:
242
+ if (res := worker_init(
243
+ res_queue=res_queue,
244
+ allow_token=allow_token,
245
+ nvidia_uuid=nvidia_uuid,
246
+ fds=fds,
247
+ )) is not None:
248
+ res_queue.put(res)
249
+ return
250
+ initialized = True
251
+ GradioPartialContext.apply(gradio_context)
252
+ context = copy_context()
253
+ with ThreadPoolExecutor() as executor:
254
+ future = executor.submit(context.run, task, *args, **kwargs) # type: ignore
255
+ try:
256
+ res = future.result()
257
+ except Exception as e:
258
+ res = exception_result(e)
259
+ else:
260
+ res = OkResult(res)
261
+ try:
262
+ res_queue.put(res)
263
+ except PicklingError as e:
264
+ res_queue.put(exception_result(e))
265
+
266
+ # https://github.com/python/cpython/issues/91002
267
+ if not hasattr(task, '__annotations__'):
268
+ gradio_handler.__annotations__ = {}
269
+
270
+ return gradio_handler
271
+
272
+
273
+ def generator_function_wrapper(
274
+ task: Callable[Param, Generator[Res, None, None]],
275
+ duration: DynamicDuration[Param],
276
+ ) -> Callable[Param, Generator[Res, None, None]]:
277
+
278
+ import gradio as gr
279
+
280
+ request_var = gradio_request_var()
281
+ workers: dict[NvidiaIndex, Worker[GeneratorResQueueResult[Res] | None]] = {}
282
+ task_id = id(task)
283
+
284
+ @wraps(task)
285
+ def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Generator[Res, None, None]:
286
+
287
+ if forked:
288
+ yield from task(*args, **kwargs)
289
+ return
290
+
291
+ request = request_var.get()
292
+ duration_ = static_duration(duration, *args, **kwargs)
293
+ duration_ = process_duration(duration_)
294
+ schedule_response = client.schedule(task_id=task_id, request=request, duration=duration_)
295
+ allow_token = schedule_response.allowToken
296
+ nvidia_index = schedule_response.nvidiaIndex
297
+ nvidia_uuid = schedule_response.nvidiaUUID
298
+ release = partial(client.release, allow_token)
299
+
300
+ try:
301
+ worker = workers.pop(nvidia_index)
302
+ except KeyError:
303
+ worker = None
304
+
305
+ if worker is not None and worker.process.is_alive() and schedule_response.idle:
306
+ assert worker.arg_queue.empty()
307
+ assert worker.res_queue.empty()
308
+ else:
309
+ worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
310
+
311
+ try:
312
+ worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
313
+ except PicklingError: # TODO: detailed serialization diagnostic
314
+ release(fail=True)
315
+ raise
316
+
317
+ yield_queue: ThreadQueue[YieldQueueResult[Res]] = ThreadQueue()
318
+ def fill_yield_queue(worker: Worker[GeneratorResQueueResult[Res] | None]):
319
+ while True:
320
+ res = worker.res_queue.get()
321
+ if res is None:
322
+ release(fail=True, allow_404=True)
323
+ yield_queue.put(AbortedResult())
324
+ return
325
+ if isinstance(res, ExceptionResult):
326
+ release(fail=True)
327
+ yield_queue.put(res)
328
+ return
329
+ if isinstance(res, EndResult):
330
+ release()
331
+ workers[nvidia_index] = worker
332
+ yield_queue.put(EndResult())
333
+ return
334
+ if isinstance(res, OkResult):
335
+ yield_queue.put(OkResult(res.value))
336
+ continue
337
+ if isinstance(res, GradioQueueEvent): # pragma: no cover (not working properly on Gradio side)
338
+ try_process_queue_event(res.method_name, *res.args, **res.kwargs)
339
+ continue
340
+ debug(f"fill_yield_queue: assert_never({res=})")
341
+ assert_never(res)
342
+ from typing_extensions import assert_never
343
+ with ThreadPoolExecutor() as e:
344
+ f = e.submit(copy_context().run, fill_yield_queue, worker)
345
+ f.add_done_callback(lambda _: debug("fill_yield_queue DONE"))
346
+ while True:
347
+ try:
348
+ res = yield_queue.get(timeout=GENERATOR_GLOBAL_TIMEOUT)
349
+ except Empty: # pragma: no cover
350
+ debug(f"yield_queue TIMEOUT ({GENERATOR_GLOBAL_TIMEOUT=})")
351
+ raise
352
+ if isinstance(res, AbortedResult):
353
+ raise error("ZeroGPU worker error", "GPU task aborted")
354
+ if isinstance(res, ExceptionResult):
355
+ print(res.traceback)
356
+ raise error("ZeroGPU worker error", res.error_cls)
357
+ if isinstance(res, EndResult):
358
+ break
359
+ if isinstance(res, OkResult):
360
+ yield res.value
361
+ continue
362
+ debug(f"gradio_handler: assert_never({res=})")
363
+ assert_never(res)
364
+
365
+
366
+ def thread_wrapper(
367
+ arg_queue: Queue[tuple[Params, GradioPartialContext]],
368
+ res_queue: Queue[GeneratorResQueueResult[Res] | None],
369
+ allow_token: str,
370
+ nvidia_uuid: str,
371
+ fds: list[int],
372
+ ):
373
+ global forked
374
+ forked = True
375
+ signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
376
+ initialized = False
377
+ while True:
378
+ try:
379
+ (args, kwargs), gradio_context = arg_queue.get()
380
+ except OSError:
381
+ break
382
+ if not initialized:
383
+ if (res := worker_init(
384
+ res_queue=res_queue,
385
+ allow_token=allow_token,
386
+ nvidia_uuid=nvidia_uuid,
387
+ fds=fds,
388
+ )) is not None:
389
+ res_queue.put(res)
390
+ return
391
+ initialized = True
392
+ def iterate():
393
+ gen = task(*args, **kwargs) # type: ignore
394
+ while True:
395
+ try:
396
+ res = next(gen)
397
+ except StopIteration:
398
+ break
399
+ except Exception as e:
400
+ res_queue.put(exception_result(e))
401
+ break
402
+ try:
403
+ res_queue.put(OkResult(res))
404
+ except PicklingError as e:
405
+ res_queue.put(exception_result(e))
406
+ break
407
+ else:
408
+ continue
409
+ GradioPartialContext.apply(gradio_context)
410
+ with ThreadPoolExecutor() as executor:
411
+ executor.submit(copy_context().run, iterate)
412
+ res_queue.put(EndResult())
413
+
414
+ # https://github.com/python/cpython/issues/91002
415
+ if not hasattr(task, '__annotations__'):
416
+ gradio_handler.__annotations__ = {}
417
+
418
+ return gradio_handler
419
+
420
+
421
+ def exception_result(exc: Exception) -> ExceptionResult:
422
+ formatted = traceback.format_exception(type(exc), exc, exc.__traceback__)
423
+ return ExceptionResult(traceback=''.join(formatted), error_cls=exc.__class__.__name__)
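Hedged usage sketch to tie the wrappers together: a generator task is assumed to be routed to generator_function_wrapper by the public GPU decorator; each yielded value crosses the worker's res_queue as an OkResult and is re-yielded in the Gradio process, and EndResult releases the GPU and recycles the worker.

import spaces

@spaces.GPU(duration=60)
def stream_tokens(prompt: str):
    # runs in the forked GPU worker process
    for token in ("Zero", "GPU"):
        yield token  # forwarded through the yield queue as OkResult values

# Iterating stream_tokens(...) in the main process yields values as they arrive.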