from typing import Dict, List, Optional

import torch

# UNet attention blocks used to separate content from style: as the keys
# indicate, the first attention block of up_blocks.0 carries content and the
# second carries style.
BLOCKS = {
    'content': ['unet.up_blocks.0.attentions.0'],
    'style': ['unet.up_blocks.0.attentions.1'],
}
|
|
def is_belong_to_blocks(key: str, blocks: List[str]) -> bool:
    """Return True if `key` falls under any of the given block names."""
    try:
        return any(block in key for block in blocks)
    except Exception as e:
        raise type(e)(f'failed to is_belong_to_blocks, due to: {e}')
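
# Example (illustrative key; matching is by substring):
#   is_belong_to_blocks('unet.up_blocks.0.attentions.0.processor', BLOCKS['content'])  # -> True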
|
|
def filter_lora(state_dict: Dict[str, torch.Tensor], blocks_: List[str]) -> Dict[str, torch.Tensor]:
    """Keep only the LoRA weights whose keys belong to the given blocks."""
    try:
        return {k: v for k, v in state_dict.items() if is_belong_to_blocks(k, blocks_)}
    except Exception as e:
        raise type(e)(f'failed to filter_lora, due to: {e}')
|
|
def scale_lora(state_dict: Dict[str, torch.Tensor], alpha: float) -> Dict[str, torch.Tensor]:
    """Scale every LoRA weight by `alpha` to strengthen or weaken its effect."""
    try:
        return {k: v * alpha for k, v in state_dict.items()}
    except Exception as e:
        raise type(e)(f'failed to scale_lora, due to: {e}')
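
# A minimal sketch of how filter_lora and scale_lora compose: split a loaded
# LoRA state dict into its content and style parts, then re-weight the style
# part. The keys and alpha below are illustrative placeholders, not values
# this module prescribes.
def _demo_filter_and_scale() -> None:
    state_dict = {
        'unet.up_blocks.0.attentions.0.to_q.lora.down.weight': torch.ones(4, 4),
        'unet.up_blocks.0.attentions.1.to_q.lora.down.weight': torch.ones(4, 4),
    }
    content_lora = filter_lora(state_dict, BLOCKS['content'])
    style_lora = scale_lora(filter_lora(state_dict, BLOCKS['style']), alpha=0.5)
    assert set(content_lora) == {'unet.up_blocks.0.attentions.0.to_q.lora.down.weight'}
    assert float(style_lora['unet.up_blocks.0.attentions.1.to_q.lora.down.weight'][0, 0]) == 0.5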
|
|
def get_target_modules(unet, blocks: Optional[List[str]] = None) -> List[str]:
    """Collect the attention projection module names to attach LoRA layers to."""
    try:
        if not blocks:
            # Strip the leading 'unet.' prefix: the keys in unet.attn_processors
            # are relative to the UNet module itself.
            blocks = ['.'.join(blk.split('.')[1:]) for blk in BLOCKS['content'] + BLOCKS['style']]

        # Drop the trailing '.processor' segment from each matching key to
        # recover the attention module path.
        attns = [
            attn_processor_name.rsplit('.', 1)[0]
            for attn_processor_name, _ in unet.attn_processors.items()
            if is_belong_to_blocks(attn_processor_name, blocks)
        ]

        # Target the query, key, value, and output projections of every
        # matched attention module.
        return [f'{attn}.{mat}' for mat in ['to_k', 'to_q', 'to_v', 'to_out.0'] for attn in attns]
    except Exception as e:
        raise type(e)(f'failed to get_target_modules, due to: {e}')
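
# A hedged end-to-end sketch of how get_target_modules is typically consumed:
# passing the collected module names to a peft LoraConfig and attaching the
# adapter to a diffusers UNet. The model id and LoRA hyperparameters are
# illustrative assumptions, not values fixed by this module.
if __name__ == '__main__':
    from diffusers import UNet2DConditionModel
    from peft import LoraConfig

    unet = UNet2DConditionModel.from_pretrained(
        'stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet'
    )
    lora_config = LoraConfig(
        r=64,
        lora_alpha=64,
        init_lora_weights='gaussian',
        target_modules=get_target_modules(unet),
    )
    unet.add_adapter(lora_config)
    print(f'attached LoRA to {len(get_target_modules(unet))} projection modules')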
|