Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
from setuptools import find_packages, setup | |
import os | |
import subprocess | |
import sys | |
import time | |
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension | |
from utils.misc import gpu_is_available | |
# Relative path (resolved against the current working directory) of the
# generated version module: write_version_py() writes it, get_version() reads
# it back, and get_hash() checks for its existence.
version_file = './basicsr/version.py'
def readme():
    """Return the full text of ``README.md`` from the current directory.

    The file is decoded as UTF-8; the result is used as the package's
    long description.
    """
    with open('README.md', mode='r', encoding='utf-8') as handle:
        return handle.read()
def get_git_hash():
    """Return the full SHA of git ``HEAD`` as a string.

    ``git rev-parse HEAD`` runs with a minimal environment. If the ``git``
    executable cannot be started (``OSError``), ``'unknown'`` is returned.
    If git runs but prints nothing (e.g. outside a repository), the result
    is the empty string.
    """

    def _run_minimal(command):
        # Forward only what is needed to locate/run git, and pin the C locale
        # so the output is not affected by the user's language settings.
        minimal_env = {
            key: os.environ[key]
            for key in ('SYSTEMROOT', 'PATH', 'HOME')
            if key in os.environ
        }
        # LANGUAGE is used on win32
        minimal_env.update(LANGUAGE='C', LANG='C', LC_ALL='C')
        finished = subprocess.run(command, stdout=subprocess.PIPE, env=minimal_env)
        return finished.stdout

    try:
        raw_output = _run_minimal(['git', 'rev-parse', 'HEAD'])
    except OSError:
        return 'unknown'
    return raw_output.strip().decode('ascii')
def get_hash():
    """Return a short revision identifier for the version file.

    Resolution order:
      1. A ``.git`` directory in the cwd: the first 7 characters of
         :func:`get_git_hash`.
      2. ``version_file`` exists: the text after the last ``'+'`` in the
         ``__version__`` of the importable ``version`` module (the whole
         string when there is no ``'+'``).
      3. Otherwise ``'unknown'``.

    Raises:
        ImportError: if ``version_file`` exists but ``version`` cannot be
            imported.
    """
    if os.path.exists('.git'):
        return get_git_hash()[:7]
    if not os.path.exists(version_file):
        return 'unknown'
    # NOTE(review): this imports the top-level module ``version`` via
    # sys.path, which is not necessarily the file checked above — confirm.
    try:
        from version import __version__
    except ImportError:
        raise ImportError('Unable to get git version')
    return __version__.split('+')[-1]
def write_version_py():
    """Generate ``version_file`` from ``./basicsr/VERSION`` and the git hash.

    The generated module records the generation time and defines:

    * ``__version__``: the stripped contents of ``./basicsr/VERSION``;
    * ``__gitsha__``: the value returned by :func:`get_hash`;
    * ``version_info``: a tuple of the dot-separated version components.
      Purely numeric components are emitted as ints, others as strings.
    """
    content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('./basicsr/VERSION', 'r', encoding='utf-8') as f:
        short_version = f.read().strip()
    parts = [x if x.isdigit() else f'"{x}"' for x in short_version.split('.')]
    version_info = ', '.join(parts)
    if len(parts) == 1:
        # "(2)" is just a parenthesised int; a one-element tuple needs a
        # trailing comma.
        version_info += ','
    version_file_str = content.format(time.asctime(), short_version, sha, version_info)
    with open(version_file, 'w', encoding='utf-8') as f:
        f.write(version_file_str)
def get_version(path=None):
    """Return the ``__version__`` string defined in the generated version module.

    Args:
        path: Python file to read. Defaults to ``version_file``.

    Returns:
        The value the file assigns to ``__version__``.

    Raises:
        KeyError: if the file does not define ``__version__``.
    """
    if path is None:
        path = version_file
    with open(path, 'r', encoding='utf-8') as f:
        source = f.read()
    # Execute into an explicit namespace. Since Python 3.13 (PEP 667),
    # exec() inside a function writes into a throwaway snapshot of locals(),
    # so reading the result back through locals() raises KeyError.
    namespace = {}
    exec(compile(source, path, 'exec'), namespace)
    return namespace['__version__']
def make_cuda_ext(name, module, sources, sources_cuda=None):
    """Describe a torch native extension named ``f'{module}.{name}'``.

    A ``CUDAExtension`` is returned when a GPU is reported available or the
    environment sets ``FORCE_CUDA=1``. It defines the ``WITH_CUDA`` macro,
    passes nvcc flags that disable the CUDA half-precision operator and
    conversion overloads, and compiles ``sources`` plus ``sources_cuda``.
    Otherwise a message is printed and a CPU-only ``CppExtension`` built
    from ``sources`` alone is returned.

    Args:
        name: Extension name (last component of the full module name).
        module: Dotted package path. Source paths are resolved relative to
            the directory it maps to.
        sources: Source files compiled in every configuration. Not modified.
        sources_cuda: Extra source files compiled only for the CUDA build.

    Returns:
        A ``CUDAExtension`` or ``CppExtension`` instance.
    """
    if sources_cuda is None:
        sources_cuda = []
    define_macros = []
    extra_compile_args = {'cxx': []}
    # gpu_is_available (from utils.misc) is a predicate function and must be
    # called. The bare function object is always truthy, which previously
    # forced a CUDA build on CPU-only machines and made FORCE_CUDA dead code.
    if gpu_is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        define_macros += [('WITH_CUDA', None)]
        extension = CUDAExtension
        extra_compile_args['nvcc'] = [
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]
        # Build a new list rather than extending in place, so the caller's
        # ``sources`` list is left untouched.
        sources = list(sources) + list(sources_cuda)
    else:
        print(f'Compiling {name} without CUDA')
        extension = CppExtension
    return extension(
        name=f'{module}.{name}',
        sources=[os.path.join(*module.split('.'), p) for p in sources],
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)
def get_requirements(filename='requirements.txt'):
    """Return the requirement specifiers listed in ``filename``.

    Each line is stripped of surrounding whitespace. Blank lines and lines
    starting with ``#`` are skipped so they never reach ``install_requires``
    as empty or invalid entries.

    Args:
        filename: Path of the requirements file, relative to the cwd or
            absolute.

    Returns:
        A list of requirement strings in file order.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        return [req for req in stripped if req and not req.startswith('#')]
if __name__ == '__main__':
    # Native ops are only compiled when explicitly requested with --cuda_ext.
    # The custom flag is then removed from sys.argv before setup() parses the
    # command line.
    if '--cuda_ext' in sys.argv:
        ext_modules = [
            make_cuda_ext(
                name='deform_conv_ext',
                module='ops.dcn',
                sources=['src/deform_conv_ext.cpp'],
                sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
            make_cuda_ext(
                name='fused_act_ext',
                module='ops.fused_act',
                sources=['src/fused_bias_act.cpp'],
                sources_cuda=['src/fused_bias_act_kernel.cu']),
            make_cuda_ext(
                name='upfirdn2d_ext',
                module='ops.upfirdn2d',
                sources=['src/upfirdn2d.cpp'],
                sources_cuda=['src/upfirdn2d_kernel.cu']),
        ]
        sys.argv.remove('--cuda_ext')
    else:
        ext_modules = []

    # Regenerate the version module first: get_version() below reads it back.
    write_version_py()
    setup(
        name='basicsr',
        version=get_version(),
        description='Open Source Image and Video Super-Resolution Toolbox',
        long_description=readme(),
        long_description_content_type='text/markdown',
        author='Xintao Wang',
        # Restored from the Cloudflare email-obfuscation placeholder
        # "[email protected]" left behind by a web scrape.
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, restoration, super resolution',
        url='https://github.com/xinntao/BasicSR',
        include_package_data=True,
        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
        classifiers=[
            'Development Status :: 4 - Beta',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
        ],
        license='Apache License 2.0',
        setup_requires=['cython', 'numpy'],
        install_requires=get_requirements(),
        ext_modules=ext_modules,
        cmdclass={'build_ext': BuildExtension},
        zip_safe=False)