Commit 
							
							·
						
						40d1ba9
	
1
								Parent(s):
							
							d7c590b
								
up
Browse files- convert_flax_to_pt.py +2 -1
 - mass_open_controlnet_pr.sh +6 -0
 - model_ids.txt +0 -0
 
    	
        convert_flax_to_pt.py
    CHANGED
    
    | 
         @@ -2,6 +2,7 @@ import argparse 
     | 
|
| 2 | 
         
             
            import json
         
     | 
| 3 | 
         
             
            import os
         
     | 
| 4 | 
         
             
            import shutil
         
     | 
| 
         | 
|
| 5 | 
         
             
            import torch
         
     | 
| 6 | 
         
             
            from tempfile import TemporaryDirectory
         
     | 
| 7 | 
         
             
            from typing import List, Optional
         
     | 
| 
         @@ -18,7 +19,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi 
     | 
|
| 18 | 
         
             
                is_sd = "model_index.json" in filenames
         
     | 
| 19 | 
         | 
| 20 | 
         
             
                if is_sd:
         
     | 
| 21 | 
         
            -
                    model = StableDiffusionPipeline.from_pretrained(model_id, from_flax=True)
         
     | 
| 22 | 
         
             
                else:
         
     | 
| 23 | 
         
             
                    model = ControlNetModel.from_pretrained(model_id, from_flax=True)
         
     | 
| 24 | 
         | 
| 
         | 
|
| 2 | 
         
             
            import json
         
     | 
| 3 | 
         
             
            import os
         
     | 
| 4 | 
         
             
            import shutil
         
     | 
| 5 | 
         
            +
            from diffusers.pipelines.stable_diffusion import safety_checker
         
     | 
| 6 | 
         
             
            import torch
         
     | 
| 7 | 
         
             
            from tempfile import TemporaryDirectory
         
     | 
| 8 | 
         
             
            from typing import List, Optional
         
     | 
| 
         | 
|
| 19 | 
         
             
                is_sd = "model_index.json" in filenames
         
     | 
| 20 | 
         | 
| 21 | 
         
             
                if is_sd:
         
     | 
| 22 | 
         
            +
                    model = StableDiffusionPipeline.from_pretrained(model_id, from_flax=True, safety_checker=None)
         
     | 
| 23 | 
         
             
                else:
         
     | 
| 24 | 
         
             
                    model = ControlNetModel.from_pretrained(model_id, from_flax=True)
         
     | 
| 25 | 
         | 
    	
        mass_open_controlnet_pr.sh
    ADDED
    
    | 
         @@ -0,0 +1,6 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            #!/usr/bin/env bash
         
     | 
| 2 | 
         
            +
            while read p; do
         
     | 
| 3 | 
         
            +
            	echo "-------------------------------"
         
     | 
| 4 | 
         
            +
            	echo "Open PR for $p"
         
     | 
| 5 | 
         
            +
            	python convert_flax_to_pt.py $p
         
     | 
| 6 | 
         
            +
            done
         
     | 
    	
        model_ids.txt
    CHANGED
    
    | 
         The diff for this file is too large to render. 
		See raw diff 
     | 
| 
         |