Spaces:
Paused
Paused
| # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
| # // | |
| # // Licensed under the Apache License, Version 2.0 (the "License"); | |
| # // you may not use this file except in compliance with the License. | |
| # // You may obtain a copy of the License at | |
| # // | |
| # // http://www.apache.org/licenses/LICENSE-2.0 | |
| # // | |
| # // Unless required by applicable law or agreed to in writing, software | |
| # // distributed under the License is distributed on an "AS IS" BASIS, | |
| # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # // See the License for the specific language governing permissions and | |
| # // limitations under the License. | |
| from typing import Callable | |
class Cache:
    """Memoize the results of zero-argument callables under string keys.

    Every key is stored as ``prefix + key``. Instances created through
    :meth:`namespace` reuse the same underlying mapping, so entries written
    through one view are visible through the others under their full key.
    """

    def __init__(self, disable=False, prefix="", cache=None):
        """Set up the cache view.

        Args:
            disable: When true, ``__call__`` always invokes the callable and
                never reads from or writes to the store.
            prefix: String prepended to every key handled by this instance.
            cache: Existing mapping to use as the store; a new dict is
                created when ``None`` is given.
        """
        self.prefix = prefix
        self.disable = disable
        self.cache = {} if cache is None else cache

    def __call__(self, key: str, fn: Callable):
        """Return the value stored for ``key``, computing it with ``fn`` on a miss.

        Args:
            key: Key relative to this instance's prefix.
            fn: Zero-argument callable that produces the value.

        Returns:
            The stored value, or the fresh result of ``fn()`` when the key
            was missing or caching is disabled.
        """
        if not self.disable:
            full_key = self.prefix + key
            try:
                return self.cache[full_key]
            except KeyError:
                # Miss: compute once and remember the value for later calls.
                value = fn()
                self.cache[full_key] = value
                return value
        # Caching is switched off: always recompute, never touch the store.
        return fn()

    def namespace(self, namespace: str):
        """Return a ``Cache`` sharing this store whose prefix is extended.

        The new prefix is this instance's prefix followed by ``namespace``
        and a trailing ``"."``; the ``disable`` flag is carried over.
        """
        child_prefix = self.prefix + namespace + "."
        return Cache(disable=self.disable, prefix=child_prefix, cache=self.cache)

    def get(self, key: str):
        """Return the value stored under ``prefix + key``.

        The lookup is a plain subscript on the store, so a missing key
        raises whatever the store raises (``KeyError`` for a dict).
        """
        return self.cache[self.prefix + key]