"""
Helper utilities for tests.
"""
import os
import shutil
import tempfile
import unittest
from functools import wraps
from importlib.metadata import version
from pathlib import Path
def with_temp_dir(test_func):
    """
    Decorator that runs ``test_func`` with a freshly created temporary directory.

    A new directory is created with ``tempfile.mkdtemp`` on every call and its
    path is passed to ``test_func`` as the ``temp_dir`` keyword argument. The
    directory is removed with ``shutil.rmtree`` once ``test_func`` returns or
    raises.

    Returns:
        A wrapper with the metadata of ``test_func`` that forwards all
        positional/keyword arguments and returns whatever ``test_func`` returns.
    """

    @wraps(test_func)
    def wrapper(*args, **kwargs):
        temp_dir = tempfile.mkdtemp()
        try:
            # Propagate the result instead of discarding it, so decorating a
            # function never silently changes what it returns.
            return test_func(*args, temp_dir=temp_dir, **kwargs)
        finally:
            # Clean up even when the wrapped function raised.
            shutil.rmtree(temp_dir)

    return wrapper
def most_recent_subdir(path):
    """
    Return the immediate subdirectory of ``path`` with the largest
    ``os.path.getctime`` value, as a ``Path``.

    Only entries for which ``Path.is_dir()`` is true are considered. Returns
    ``None`` when ``path`` contains no subdirectories.
    """
    candidates = filter(Path.is_dir, Path(path).iterdir())
    return max(candidates, key=os.path.getctime, default=None)
def require_torch_2_1_1(test_case):
    """
    Decorator marking a test that requires torch >= 2.1.1

    The installed torch version is read from package metadata when the
    decorator is applied. The test is skipped (reason ``"test torch 2.1.1"``)
    when the version is lower than 2.1.1 or when torch is not installed.
    """
    # Local imports keep this block self-contained; the module only imports
    # `version` from importlib.metadata at file level.
    import re
    from importlib.metadata import PackageNotFoundError

    def release_tuple(raw_version):
        # Keep only the leading dotted-integer release segment so suffixes do
        # not take part in the comparison:
        #   "2.1.1+cu118" -> (2, 1, 1), "2.2.0.dev20231010" -> (2, 2, 0)
        match = re.match(r"\d+(?:\.\d+)*", raw_version)
        if match is None:
            return ()
        return tuple(int(part) for part in match.group().split("."))

    def is_min_2_1_1():
        try:
            torch_version = version("torch")
        except PackageNotFoundError:
            # No torch at all: the requirement is not met, so skip rather
            # than crash while the test module is being imported.
            return False
        # Compare numerically; plain string comparison would rank "10.0.0"
        # below "2.1.1".
        return release_tuple(torch_version) >= (2, 1, 1)

    return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)
 |