Spaces · Running on Zero

hpc-yekin committed
Commit 92e0882 · Parent(s): 4aa0b3a

initial commit

This view is limited to 50 files because it contains too many changes. See raw diff.
- AlphaCLIP/.gitignore +12 -0
 - AlphaCLIP/LICENSE +201 -0
 - AlphaCLIP/MANIFEST.in +1 -0
 - AlphaCLIP/alpha_clip/__init__.py +1 -0
 - AlphaCLIP/alpha_clip/alpha_clip.py +250 -0
 - AlphaCLIP/alpha_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
 - AlphaCLIP/alpha_clip/model.py +598 -0
 - AlphaCLIP/alpha_clip/simple_tokenizer.py +132 -0
 - AlphaCLIP/eval/README.md +6 -0
 - AlphaCLIP/eval/imagenet_s_zs_test/.gitignore +2 -0
 - AlphaCLIP/eval/imagenet_s_zs_test/README.md +21 -0
 - AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s.py +149 -0
 - AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s_zs_test.py +66 -0
 - AlphaCLIP/eval/rec_zs_test/LICENSE.md +201 -0
 - AlphaCLIP/eval/rec_zs_test/README.md +74 -0
 - AlphaCLIP/eval/rec_zs_test/cache/.gitkeep +0 -0
 - AlphaCLIP/eval/rec_zs_test/cal_acc.py +21 -0
 - AlphaCLIP/eval/rec_zs_test/ckpt/.gitkeep +0 -0
 - AlphaCLIP/eval/rec_zs_test/data/.gitkeep +0 -0
 - AlphaCLIP/eval/rec_zs_test/entity_extraction.py +142 -0
 - AlphaCLIP/eval/rec_zs_test/executor.py +401 -0
 - AlphaCLIP/eval/rec_zs_test/generic_clip_pairs.py +107 -0
 - AlphaCLIP/eval/rec_zs_test/heuristics.py +68 -0
 - AlphaCLIP/eval/rec_zs_test/interpreter.py +212 -0
 - AlphaCLIP/eval/rec_zs_test/lattice.py +70 -0
 - AlphaCLIP/eval/rec_zs_test/main.py +200 -0
 - AlphaCLIP/eval/rec_zs_test/methods/__init__.py +3 -0
 - AlphaCLIP/eval/rec_zs_test/methods/baseline.py +57 -0
 - AlphaCLIP/eval/rec_zs_test/methods/parse.py +239 -0
 - AlphaCLIP/eval/rec_zs_test/methods/random_method.py +30 -0
 - AlphaCLIP/eval/rec_zs_test/methods/ref_method.py +13 -0
 - AlphaCLIP/eval/rec_zs_test/output/.gitkeep +0 -0
 - AlphaCLIP/eval/rec_zs_test/requirements.txt +53 -0
 - AlphaCLIP/eval/rec_zs_test/run.sh +1 -0
 - AlphaCLIP/eval/rec_zs_test/run_multi_gpus.sh +15 -0
 - AlphaCLIP/hubconf.py +42 -0
 - AlphaCLIP/requirements.txt +5 -0
 - AlphaCLIP/setup.py +21 -0
 - README.md +1 -1
 - app.py +113 -0
 - clip_l14_grit+mim_fultune_6xe.pth +3 -0
 - config/inference_config.yaml +16 -0
 - image_encoder/config.json +23 -0
 - image_encoder/pytorch_model.bin +3 -0
 - ip-adapter_sd15.bin +3 -0
 - model.safetensors +3 -0
 - model/__init__.py +5 -0
 - model/attention_processor.py +189 -0
 - model/clip_away.py +280 -0
 - model/resampler.py +158 -0
 
    	
AlphaCLIP/.gitignore  ADDED
@@ -0,0 +1,12 @@
+__pycache__/
+*.py[cod]
+*$py.class
+*.egg-info
+.pytest_cache
+.ipynb_checkpoints
+
+thumbs.db
+.DS_Store
+.idea
+checkpoints/*
+*.pth
    	
AlphaCLIP/LICENSE  ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [Zeyi Sun] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
    	
AlphaCLIP/MANIFEST.in  ADDED
@@ -0,0 +1 @@
+include alpha_clip/bpe_simple_vocab_16e6.txt.gz
    	
AlphaCLIP/alpha_clip/__init__.py  ADDED
@@ -0,0 +1 @@
+from .alpha_clip import *
    	
AlphaCLIP/alpha_clip/alpha_clip.py  ADDED
@@ -0,0 +1,250 @@
+import hashlib
+import os
+import urllib
+import warnings
+from typing import Any, Union, List
+from pkg_resources import packaging
+
+import torch
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+from tqdm import tqdm
+
+from .model import build_model
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+try:
+    from torchvision.transforms import InterpolationMode
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
+
+if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
+    warnings.warn("PyTorch version 1.7.1 or higher is recommended")
+
+
+__all__ = ["available_models", "load", "tokenize"]
+_tokenizer = _Tokenizer()
+
+_MODELS = {
+    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
+}
+
+
+def _download(url: str, root: str):
+    os.makedirs(root, exist_ok=True)
+    filename = os.path.basename(url)
+
+    expected_sha256 = url.split("/")[-2]
+    download_target = os.path.join(root, filename)
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+            return download_target
+        else:
+            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
+
+    return download_target
+
+
+def _convert_image_to_rgb(image):
+    return image.convert("RGB")
+
+
+def _transform(n_px):
+    return Compose([
+        Resize(n_px, interpolation=BICUBIC),
+        CenterCrop(n_px),
+        _convert_image_to_rgb,
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+    ])
+
+
+def available_models() -> List[str]:
+    """Returns the names of available CLIP models"""
+    return list(_MODELS.keys())
+
+
+def load(name: str, alpha_vision_ckpt_pth="None", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, lora_adapt=False, rank=16):
+    """Load a CLIP model
+
+    Parameters
+    ----------
+    name : str
+        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+
+    alpha_vision_ckpt_pth: str
+        only changed when inferencing model instead of training
+
+    device : Union[str, torch.device]
+        The device to put the loaded model
+
+    jit : bool
+        Whether to load the optimized JIT model or more hackable non-JIT model (default).
+
+    download_root: str
+        path to download the model files; by default, it uses "~/.cache/clip"
+
+    Returns
+    -------
+    model : torch.nn.Module
+        The CLIP model
+
+    preprocess : Callable[[PIL.Image], torch.Tensor]
+        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+    """
+    if name in _MODELS:
+        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
+    elif os.path.isfile(name):
+        model_path = name
+    else:
+        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+    with open(model_path, 'rb') as opened_file:
+        try:
+            # loading JIT archive
+            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
+            state_dict = None
+        except RuntimeError:
+            # loading saved state dict
+            if jit:
+                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+                jit = False
+            state_dict = torch.load(opened_file, map_location="cpu")
+
+    if not jit:
+        model = build_model(state_dict or model.state_dict(), lora_adapt=lora_adapt, rank=rank).to(device)
+        if str(device) == "cpu":
+            model.float()
+        if alpha_vision_ckpt_pth != "None":
+            model.visual.load_state_dict(torch.load(alpha_vision_ckpt_pth))
+            model.eval() # merge lora params if exists (for inference only)
+        return model, _transform(model.visual.input_resolution)
+
+    # patch the device names
+    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+    def _node_get(node: torch._C.Node, key: str):
+        """Gets attributes of a node which is polymorphic over return type.
+
+        From https://github.com/pytorch/pytorch/pull/82628
+        """
+        sel = node.kindOf(key)
+        return getattr(node, sel)(key)
+
+    def patch_device(module):
+        try:
+            graphs = [module.graph] if hasattr(module, "graph") else []
+        except RuntimeError:
+            graphs = []
+
+        if hasattr(module, "forward1"):
+            graphs.append(module.forward1.graph)
+
+        for graph in graphs:
+            for node in graph.findAllNodes("prim::Constant"):
+                if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
+                    node.copyAttributes(device_node)
+
+    model.apply(patch_device)
+    patch_device(model.encode_image)
+    patch_device(model.encode_text)
+
+    # patch dtype to float32 on CPU
+    if str(device) == "cpu":
+        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+        float_node = float_input.node()
+
+        def patch_float(module):
+            try:
+                graphs = [module.graph] if hasattr(module, "graph") else []
+            except RuntimeError:
+                graphs = []
+
+            if hasattr(module, "forward1"):
+                graphs.append(module.forward1.graph)
+
+            for graph in graphs:
+                for node in graph.findAllNodes("aten::to"):
+                    inputs = list(node.inputs())
+                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
+                        if _node_get(inputs[i].node(), "value") == 5:
+                            inputs[i].node().copyAttributes(float_node)
+
+        model.apply(patch_float)
+        patch_float(model.encode_image)
+        patch_float(model.encode_text)
+
+        model.float()
+    return model, _transform(model.input_resolution.item())
+
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True) -> Union[torch.IntTensor, torch.LongTensor]:
+    """
+    Returns the tokenized representation of given input string(s)
+
+    Parameters
+    ----------
+    texts : Union[str, List[str]]
+        An input string or a list of input strings to tokenize
+
+    context_length : int
+        The context length to use; all CLIP models use 77 as the context length
+
+    truncate: bool
+        Whether to truncate the text in case its encoding is longer than the context length
+
+    Returns
+    -------
+    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
+    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
+    """
+    if isinstance(texts, str):
+        texts = [texts]
+
+    sot_token = _tokenizer.encoder["<|startoftext|>"]
+    eot_token = _tokenizer.encoder["<|endoftext|>"]
+    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+    else:
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
+
+    for i, tokens in enumerate(all_tokens):
+        if len(tokens) > context_length:
+            if truncate:
+                tokens = tokens[:context_length]
+                tokens[-1] = eot_token
[diff truncated in this view; lines 246-250 of alpha_clip.py are not shown]
         
            +
                        else:
         
     | 
| 247 | 
         
            +
                            raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
         
     | 
| 248 | 
         
            +
                    result[i, :len(tokens)] = torch.tensor(tokens)
         
     | 
| 249 | 
         
            +
             
     | 
| 250 | 
         
            +
                return result
         
     | 
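A minimal usage sketch for the `tokenize` function above (illustrative, not part of the committed file). It assumes the package is importable as `alpha_clip`, which matches this repository layout; the behaviour shown follows the docstring and the default `truncate=True`.

    import alpha_clip

    # Two prompts -> one row of 77 token ids each, zero-padded after <|endoftext|>.
    tokens = alpha_clip.tokenize(["a photo of a cat", "a photo of a dog"])
    print(tokens.shape)        # torch.Size([2, 77])

    # Over-long input is cut to the context length; the last kept position
    # is overwritten with the <|endoftext|> token because truncate=True.
    long_tokens = alpha_clip.tokenize("a " * 200)
    print(long_tokens.shape)   # torch.Size([1, 77])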
    	
        AlphaCLIP/alpha_clip/bpe_simple_vocab_16e6.txt.gz
    ADDED
    
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
    	
        AlphaCLIP/alpha_clip/model.py
    ADDED
    
@@ -0,0 +1,598 @@
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+import loralib as lora
+import math
+import collections
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1):
+        super().__init__()
+
+        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu3 = nn.ReLU(inplace=True)
+
+        self.downsample = None
+        self.stride = stride
+
+        if stride > 1 or inplanes != planes * Bottleneck.expansion:
+            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+            self.downsample = nn.Sequential(OrderedDict([
+                ("-1", nn.AvgPool2d(stride)),
+                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+                ("1", nn.BatchNorm2d(planes * self.expansion))
+            ]))
+
+    def forward(self, x: torch.Tensor):
+        identity = x
+
+        out = self.relu1(self.bn1(self.conv1(x)))
+        out = self.relu2(self.bn2(self.conv2(out)))
+        out = self.avgpool(out)
+        out = self.bn3(self.conv3(out))
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu3(out)
+        return out
+
+
+class AttentionPool2d(nn.Module):
+    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x[:1], key=x, value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False
+        )
+        return x.squeeze(0)
+
+
+class ModifiedResNet(nn.Module):
+    """
+    A ResNet class that is similar to torchvision's but contains the following changes:
+    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+    - The final pooling layer is a QKV attention instead of an average pool
+    """
+
+    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+        super().__init__()
+        self.output_dim = output_dim
+        self.input_resolution = input_resolution
+
+        # the 3-layer stem
+        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+        self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(width // 2)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(width // 2)
+        self.relu2 = nn.ReLU(inplace=True)
+        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(width)
+        self.relu3 = nn.ReLU(inplace=True)
+        self.avgpool = nn.AvgPool2d(2)
+
+        # residual layers
+        self._inplanes = width  # this is a *mutable* variable used during construction
+        self.layer1 = self._make_layer(width, layers[0])
+        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+        embed_dim = width * 32  # the ResNet feature dimension
+        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+    def _make_layer(self, planes, blocks, stride=1):
+        layers = [Bottleneck(self._inplanes, planes, stride)]
+
+        self._inplanes = planes * Bottleneck.expansion
+        for _ in range(1, blocks):
+            layers.append(Bottleneck(self._inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x, alpha=None):
+        def stem(x):
+            x = self.relu1(self.bn1(self.conv1(x) + self.conv1_alpha(alpha)))
+            x = self.relu2(self.bn2(self.conv2(x)))
+            x = self.relu3(self.bn3(self.conv3(x)))
+            x = self.avgpool(x)
+            return x
+
+        x = x.type(self.conv1.weight.dtype)
+        x = stem(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.attnpool(x)
+
+        return x
+
+
            +
            class LayerNorm(nn.LayerNorm):
         
     | 
| 161 | 
         
            +
                """Subclass torch's LayerNorm to handle fp16."""
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
                def forward(self, x: torch.Tensor):
         
     | 
| 164 | 
         
            +
                    orig_type = x.dtype
         
     | 
| 165 | 
         
            +
                    ret = super().forward(x.type(torch.float32))
         
     | 
| 166 | 
         
            +
                    return ret.type(orig_type)
         
     | 
| 167 | 
         
            +
             
     | 
| 168 | 
         
            +
             
     | 
| 169 | 
         
            +
            class QuickGELU(nn.Module):
         
     | 
| 170 | 
         
            +
                def forward(self, x: torch.Tensor):
         
     | 
| 171 | 
         
            +
                    return x * torch.sigmoid(1.702 * x)
         
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
            class Attention(nn.Module):
         
     | 
| 174 | 
         
            +
                def __init__(
         
     | 
| 175 | 
         
            +
                        self,
         
     | 
| 176 | 
         
            +
                        dim,
         
     | 
| 177 | 
         
            +
                        num_heads=8,
         
     | 
| 178 | 
         
            +
                        qkv_bias=True,
         
     | 
| 179 | 
         
            +
                        scaled_cosine=False,
         
     | 
| 180 | 
         
            +
                        scale_heads=False,
         
     | 
| 181 | 
         
            +
                        logit_scale_max=math.log(1. / 0.01),
         
     | 
| 182 | 
         
            +
                        attn_drop=0.,
         
     | 
| 183 | 
         
            +
                        proj_drop=0.,
         
     | 
| 184 | 
         
            +
                        lora_adapt=False, 
         
     | 
| 185 | 
         
            +
                        rank=16
         
     | 
| 186 | 
         
            +
                ):
         
     | 
| 187 | 
         
            +
                    super().__init__()
         
     | 
| 188 | 
         
            +
                    self.scaled_cosine = scaled_cosine
         
     | 
| 189 | 
         
            +
                    self.scale_heads = scale_heads
         
     | 
| 190 | 
         
            +
                    assert dim % num_heads == 0, 'dim should be divisible by num_heads'
         
     | 
| 191 | 
         
            +
                    self.num_heads = num_heads
         
     | 
| 192 | 
         
            +
                    self.head_dim = dim // num_heads
         
     | 
| 193 | 
         
            +
                    self.scale = self.head_dim ** -0.5
         
     | 
| 194 | 
         
            +
                    self.logit_scale_max = logit_scale_max
         
     | 
| 195 | 
         
            +
             
     | 
| 196 | 
         
            +
                    # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
         
     | 
| 197 | 
         
            +
                    if lora_adapt:
         
     | 
| 198 | 
         
            +
                        print("!!!!!!!!!!using lora for qkv projection!!!!!!!!!!")
         
     | 
| 199 | 
         
            +
                        self.in_proj = lora.MergedLinear(dim, 3*dim, r=rank, enable_lora=[True, False, True])
         
     | 
| 200 | 
         
            +
                    else:
         
     | 
| 201 | 
         
            +
                        self.in_proj = nn.Linear(dim, dim * 3)
         
     | 
| 202 | 
         
            +
                    # self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
         
     | 
| 203 | 
         
            +
                    # if qkv_bias:
         
     | 
| 204 | 
         
            +
                    #     self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
         
     | 
| 205 | 
         
            +
                    # else:
         
     | 
| 206 | 
         
            +
                    #     self.in_proj_bias = None
         
     | 
| 207 | 
         
            +
             
     | 
| 208 | 
         
            +
                    if self.scaled_cosine:
         
     | 
| 209 | 
         
            +
                        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
         
     | 
| 210 | 
         
            +
                    else:
         
     | 
| 211 | 
         
            +
                        self.logit_scale = None
         
     | 
| 212 | 
         
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         
     | 
| 213 | 
         
            +
                    if self.scale_heads:
         
     | 
| 214 | 
         
            +
                        self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
         
     | 
| 215 | 
         
            +
                    else:
         
     | 
| 216 | 
         
            +
                        self.head_scale = None
         
     | 
| 217 | 
         
            +
                    self.out_proj = nn.Linear(dim, dim) if not lora_adapt else lora.Linear(dim, dim, r=rank)
         
     | 
| 218 | 
         
            +
                    self.out_drop = nn.Dropout(proj_drop)
         
     | 
| 219 | 
         
            +
             
     | 
| 220 | 
         
            +
                def forward(self, x, attn_mask = None):
         
     | 
| 221 | 
         
            +
                    L, N, C = x.shape
         
     | 
| 222 | 
         
            +
                    q, k, v = self.in_proj(x).chunk(3, dim=-1)
         
     | 
| 223 | 
         
            +
                    q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
         
     | 
| 224 | 
         
            +
                    k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
         
     | 
| 225 | 
         
            +
                    v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
         
     | 
| 226 | 
         
            +
             
     | 
| 227 | 
         
            +
                    if self.logit_scale is not None:
         
     | 
| 228 | 
         
            +
                        attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
         
     | 
| 229 | 
         
            +
                        logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
         
     | 
| 230 | 
         
            +
                        attn = attn.view(N, self.num_heads, L, L) * logit_scale
         
     | 
| 231 | 
         
            +
                        attn = attn.view(-1, L, L)
         
     | 
| 232 | 
         
            +
                    else:
         
     | 
| 233 | 
         
            +
                        q = q * self.scale
         
     | 
| 234 | 
         
            +
                        attn = torch.bmm(q, k.transpose(-2, -1))
         
     | 
| 235 | 
         
            +
             
     | 
| 236 | 
         
            +
                    if attn_mask is not None:
         
     | 
| 237 | 
         
            +
                        if attn_mask.dtype == torch.bool:
         
     | 
| 238 | 
         
            +
                            new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
         
     | 
| 239 | 
         
            +
                            new_attn_mask.masked_fill_(attn_mask, float("-inf"))
         
     | 
| 240 | 
         
            +
                            attn_mask = new_attn_mask
         
     | 
| 241 | 
         
            +
                        attn += attn_mask
         
     | 
| 242 | 
         
            +
             
     | 
| 243 | 
         
            +
                    attn = attn.softmax(dim=-1)
         
     | 
| 244 | 
         
            +
                    attn = self.attn_drop(attn)
         
     | 
| 245 | 
         
            +
             
     | 
| 246 | 
         
            +
                    x = torch.bmm(attn, v)
         
     | 
| 247 | 
         
            +
                    if self.head_scale is not None:
         
     | 
| 248 | 
         
            +
                        x = x.view(N, self.num_heads, L, C) * self.head_scale
         
     | 
| 249 | 
         
            +
                        x = x.view(-1, L, C)
         
     | 
| 250 | 
         
            +
                    x = x.transpose(0, 1).reshape(L, N, C)
         
     | 
| 251 | 
         
            +
                    x = self.out_proj(x)
         
     | 
| 252 | 
         
            +
                    x = self.out_drop(x)
         
     | 
| 253 | 
         
            +
                    return x, attn
         
     | 
| 254 | 
         
            +
             
     | 
| 255 | 
         
            +
             
     | 
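The `Attention` block above always returns a pair: the projected output and the post-softmax attention map flattened over batch and heads. A small shape sketch (sizes are illustrative, not part of the committed file, and the import path is assumed from this layout):

    import torch
    from alpha_clip.model import Attention  # import path assumed

    attn_block = Attention(dim=768, num_heads=12)   # illustrative ViT-B-like width
    x = torch.randn(197, 1, 768)                    # (sequence, batch, channels)

    out, attn = attn_block(x)
    print(out.shape)    # torch.Size([197, 1, 768])
    print(attn.shape)   # torch.Size([12, 197, 197]) -> (batch * num_heads, L, L)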
+class CustomResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16):
+        super().__init__()
+
+        self.attn = Attention(d_model, n_head, lora_adapt=lora_adapt, rank=rank)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4) if not lora_adapt else lora.Linear(d_model, d_model*4, r=rank)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model) if not lora_adapt else lora.Linear(d_model*4, d_model, r=rank))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, attn_mask=self.attn_mask)
+
+    def forward(self, x: torch.Tensor, return_attn=False):
+        attn_out, attn = self.attention(self.ln_1(x))
+        x = x + attn_out
+        x = x + self.mlp(self.ln_2(x))
+        if return_attn:
+            return x, attn
+        else:
+            return x
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+class Transformer(nn.Module):
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+    def forward(self, x: torch.Tensor):
+        return self.resblocks(x)
+
+class CustomTransformer(nn.Module):
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[CustomResidualAttentionBlock(width, heads, attn_mask, lora_adapt=lora_adapt, rank=rank) for _ in range(layers)])
+
+    def forward(self, x: torch.Tensor, return_attn=False):
+        if return_attn:
+            for i, block in enumerate(self.resblocks):
+                if i == len(self.resblocks) - 1:
+                    return block(x, return_attn=True)
+                else:
+                    x = block(x)
+            assert False
+        return self.resblocks(x)
+
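`CustomTransformer.forward` with `return_attn=True` runs every block but returns the output together with the attention map of the last block only. A short sketch (sizes illustrative, not part of the committed file, import path assumed):

    import torch
    from alpha_clip.model import CustomTransformer  # import path assumed

    tf = CustomTransformer(width=768, layers=12, heads=12)
    x = torch.randn(197, 1, 768)                 # (sequence, batch, width)

    y = tf(x)                                    # plain pass through all blocks
    y2, attn_last = tf(x, return_attn=True)      # output plus last-block attention
    print(y.shape, attn_last.shape)              # [197, 1, 768] and [12, 197, 197]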
+class VisionTransformer(nn.Module):
+    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, lora_adapt=False, rank=16):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.output_dim = output_dim
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+        self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+        self.ln_pre = LayerNorm(width)
+
+        self.transformer = CustomTransformer(width, layers, heads, lora_adapt=lora_adapt, rank=rank)
+
+        self.ln_post = LayerNorm(width)
+        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+    def forward(self, x: torch.Tensor, alpha=None, return_attn=False):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        # ASSUME alpha is always not None!
+        x = x + self.conv1_alpha(alpha)
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        if return_attn:
+            x, attn_last = self.transformer(x, return_attn=True)
+        else:
+            x = self.transformer(x, return_attn=False)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        x = self.ln_post(x[:, 0, :])
+
+        if self.proj is not None:
+            x = x @ self.proj
+        if return_attn:
+            return x, attn_last
+        else:
+            return x
+
+
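As the comment in `VisionTransformer.forward` notes, `alpha` is assumed to always be provided; a whole-image mask of ones simply treats the entire frame as foreground. A sketch with ViT-B/16-like sizes (illustrative, not read from this commit; import path assumed):

    import torch
    from alpha_clip.model import VisionTransformer  # import path assumed

    vit = VisionTransformer(input_resolution=224, patch_size=16, width=768,
                            layers=12, heads=12, output_dim=512)

    image = torch.randn(1, 3, 224, 224)
    alpha = torch.ones(1, 1, 224, 224)   # must never be None; 1-channel mask at image resolution

    with torch.no_grad():
        feats = vit(image, alpha=alpha)                            # torch.Size([1, 512])
        feats2, attn = vit(image, alpha=alpha, return_attn=True)   # plus last-block attention
    print(feats.shape, attn.shape)        # torch.Size([1, 512]) torch.Size([12, 197, 197])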
| 378 | 
         
            +
            class CLIP(nn.Module):
         
     | 
| 379 | 
         
            +
                def __init__(self,
         
     | 
| 380 | 
         
            +
                             embed_dim: int,
         
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 lora_adapt=False,
                 rank=16,
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim,
                lora_adapt=lora_adapt,
                rank=rank
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        if not hasattr(self.visual, "conv1"):
            return self.visual.module.conv1.weight.dtype
        return self.visual.conv1.weight.dtype

    def encode_image(self, image, alpha):
        assert alpha is not None
        return self.visual(image.type(self.dtype), alpha.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text, alpha):
        image_features = self.encode_image(image, alpha)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict, lora_adapt=False, rank=16):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    # always load lora version
    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
        lora_adapt=lora_adapt, rank=rank,
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]
    # para_wb to linear
    new_state_dict = collections.OrderedDict()
    for k, v in state_dict.items():
        if 'visual' in k:
            if 'in_proj_weight' in k:
                new_state_dict[k.replace('in_proj_weight', 'in_proj.weight')] = v
            elif 'in_proj_bias' in k:
                new_state_dict[k.replace('in_proj_bias', 'in_proj.bias')] = v
            else:
                new_state_dict[k] = v
        else:
            new_state_dict[k] = v

    state_dict = new_state_dict
    # add rgba_conv_weight
    if 'visual.conv1_alpha.weight' not in state_dict.keys():  # zero initialization on alpha channel
        rgb_weight = state_dict['visual.conv1.weight'].clone().detach()
        rgba_weight = torch.zeros_like(rgb_weight)[:, 0:1, :, :]
        state_dict['visual.conv1_alpha.weight'] = rgba_weight
    convert_weights(model)
    model.load_state_dict(state_dict, strict=False)
    return model.eval()
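For orientation, here is a minimal usage sketch of the pieces above: `build_model` builds Alpha-CLIP from a CLIP-format `state_dict` (zero-initializing the alpha convolution when it is absent), and `encode_image` takes an explicit alpha map alongside the image. The checkpoint filename and the input tensors below are placeholders, not part of this commit; the eval scripts further down go through the higher-level `alpha_clip.load` instead.

```python
# Minimal sketch, assuming a CLIP-format TorchScript checkpoint on disk (placeholder name).
import torch
import alpha_clip
from alpha_clip.model import build_model

state_dict = torch.jit.load("ViT-B-16.pt", map_location="cpu").state_dict()  # hypothetical checkpoint path
model = build_model(state_dict).float()  # build_model returns an eval-mode fp16 model; cast back so it also runs on CPU

image = torch.randn(1, 3, 224, 224)   # stand-in for a preprocessed RGB image
alpha = torch.ones(1, 1, 224, 224)    # alpha map covering the whole image
text = alpha_clip.tokenize(["a photo of a dog"])

with torch.no_grad():
    image_feat = model.encode_image(image, alpha)
    text_feat = model.encode_text(text)
    image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True)
    text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
    print((image_feat @ text_feat.t()).item())  # cosine similarity for the masked region
```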
    	
AlphaCLIP/alpha_clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
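A quick round-trip with the tokenizer defined above, for illustration only; the exact ids depend on the bundled `bpe_simple_vocab_16e6.txt.gz`, and `encode` by itself adds no start/end tokens.

```python
# Illustrative sketch: encode/decode a short caption with SimpleTokenizer.
from alpha_clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("a photo of a cat")  # list of BPE token ids
print(ids)
print(tokenizer.decode(ids))  # "a photo of a cat " (each word ends with a '</w>' space)
```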
    	
AlphaCLIP/eval/README.md ADDED
@@ -0,0 +1,6 @@
# Alpha-CLIP evaluation
## Zero-Shot Classification on ImageNet-S
Check out [imagenet_s_zs_test](https://github.com/SunzeY/AlphaCLIP/tree/eval-dev/eval/imagenet_s_zs_test)

## Zero-Shot Referring Expression Comprehension on RefCOCO
Check out [rec_zs_test](https://github.com/SunzeY/AlphaCLIP/tree/eval-dev/eval/rec_zs_test)
    	
AlphaCLIP/eval/imagenet_s_zs_test/.gitignore ADDED
@@ -0,0 +1,2 @@
*.json
data/*
    	
AlphaCLIP/eval/imagenet_s_zs_test/README.md ADDED
@@ -0,0 +1,21 @@
# Alpha-CLIP evaluation
## Zero-Shot Classification on ImageNet-S

1. Prepare the [ImageNet-S](https://github.com/LUSSeg/ImageNet-S) dataset; only the `validation` raw images are needed.

2. Download the [imagenet_919.json](https://download.openxlab.org.cn/models/SunzeY/AlphaCLIP/weight/imagenet_919.json) file we provide as the data annotation (generated from the ImageNet-S annotation). The folder should be structured as follows:

```
├── imagenet_s_zs_test
│   ├── data
│   │   ├── imagenet_919.json
│   │   └── ImageNetS919
│   │       └── validation
```

3. Run the test script:

```
cd eval/imagenet_s_zs_test
python imagenet_s_zs_test.py
```
    	
AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s.py ADDED
@@ -0,0 +1,149 @@
import json
import os
import random
from tqdm import tqdm
from torch.utils.data import Dataset
from pycocotools.coco import COCO
from pycocotools import mask as maskUtils
from PIL import Image
import cv2
from torchvision import transforms

import pickle
import torch
import numpy as np
import copy
import sys
import shutil
from nltk.corpus import wordnet

PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
MASK_FILL = [int(255 * c) for c in PIXEL_MEAN]


clip_standard_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), interpolation=Image.BICUBIC),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])

hi_clip_standard_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((336, 336), interpolation=Image.BICUBIC),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])

res_clip_standard_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((336, 336), interpolation=Image.BICUBIC),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])

mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(0.5, 0.26)
])

hi_mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((336, 336)),
    transforms.Normalize(0.5, 0.26)
])

res_mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((336, 336)),
    transforms.Normalize(0.5, 0.26)
])

def crop_center(img, croph, cropw):
    h, w = img.shape[:2]
    starth = h//2 - (croph//2)
    startw = w//2 - (cropw//2)
    return img[starth:starth+croph, startw:startw+cropw, :]

class Imagenet_S(Dataset):
    def __init__(self, ann_file='data/imagenet_919.json', hi_res=False, all_one=False):
        self.anns = json.load(open(ann_file, 'r'))
        self.root_pth = 'data/'
        cats = []
        for ann in self.anns:
            if ann['category_word'] not in cats:
                cats.append(ann['category_word'])
            ann['cat_index'] = len(cats) - 1
        self.classes = []
        for cat_word in cats:
            synset = wordnet.synset_from_pos_and_offset('n', int(cat_word[1:]))
            synonyms = [x.name() for x in synset.lemmas()]
            self.classes.append(synonyms[0])

        self.choice = "center_crop"
        if hi_res:
            self.mask_transform = res_mask_transform
            self.clip_standard_transform = res_clip_standard_transform
        else:
            self.mask_transform = mask_transform
            self.clip_standard_transform = clip_standard_transform

        self.all_one = all_one

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, index):
        ann = self.anns[index]
        image = cv2.imread(os.path.join(self.root_pth, ann['image_pth']))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = maskUtils.decode(ann['mask'])
        # image[mask==0] = MASK_FILL
        rgba = np.concatenate((image, np.expand_dims(mask, axis=-1)), axis=-1)
        h, w = rgba.shape[:2]

        if self.choice == "padding":
            if max(h, w) == w:
                pad = (w - h) // 2
                l, r = pad, w - h - pad
                rgba = np.pad(rgba, ((l, r), (0, 0), (0, 0)), 'constant', constant_values=0)
            else:
                pad = (h - w) // 2
                l, r = pad, h - w - pad
                rgba = np.pad(rgba, ((0, 0), (l, r), (0, 0)), 'constant', constant_values=0)
        else:
            if min(h, w) == h:
                rgba = crop_center(rgba, h, h)
            else:
                rgba = crop_center(rgba, w, w)
        rgb = rgba[:, :, :-1]
        mask = rgba[:, :, -1]
        image_torch = self.clip_standard_transform(rgb)
        # using box: bounding-box compute
        # bi_mask = mask == 1
        # h, w = bi_mask.shape[-2:]
        # in_height = np.max(bi_mask, axis=-1)
        # in_height_coords = np.max(bi_mask, axis=-1) * np.arange(h)
        # b_e = in_height_coords.max()
        # in_height_coords = in_height_coords + h * (~in_height)
        # t_e = in_height_coords.min()
        # in_width = np.max(bi_mask, axis=-2)
        # in_width_coords = np.max(bi_mask, axis=-2) * np.arange(w)
        # r_e = in_width_coords.max()
        # in_width_coords = in_width_coords + w * (~in_width)
        # l_e = in_width_coords.min()
        # box = np.zeros_like(mask)
        # box[t_e: b_e, l_e:r_e] = 1
        # mask = box
        if self.all_one:
            mask_torch = self.mask_transform(np.ones_like(mask) * 255)
        else:
            mask_torch = self.mask_transform(mask * 255)
        return image_torch, mask_torch, ann['cat_index']

if __name__ == "__main__":
    data = Imagenet_S()
    for i in tqdm(range(data.__len__())):
        data.__getitem__(i)
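As a quick sanity check of the dataset above (assuming the `data/` layout described in the README), each sample is an (image, alpha, class-index) triple; the shapes below follow from the default 224×224 transforms.

```python
# Illustrative sketch: inspect one Imagenet_S sample.
from imagenet_s import Imagenet_S

ds = Imagenet_S()                 # default low-res transforms
image, alpha, label = ds[0]
print(image.shape, alpha.shape, label)  # torch.Size([3, 224, 224]) torch.Size([1, 224, 224]) <class index>
```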
    	
AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s_zs_test.py ADDED
@@ -0,0 +1,66 @@
import torch
import alpha_clip
from tqdm import tqdm
from imagenet_s import Imagenet_S

model, preprocess = alpha_clip.load("ViT-L/14@336px", alpha_vision_ckpt_pth="../../clip_l14@336_grit_20m_4xe.pth")

def zeroshot_classifier(classnames, templates):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]  # format with class
            texts = alpha_clip.tokenize(texts).cuda()  # tokenize
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights

dataset = Imagenet_S(hi_res=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=2)

imagenet_templates = [
    'a photo of a {}.'
]

zeroshot_weights = zeroshot_classifier(dataset.classes, imagenet_templates)
temp_corr_dict = dict()

with torch.no_grad():
    for i, (images, alpha, target) in enumerate(tqdm(loader)):
        images = images.cuda()
        alpha = alpha.cuda()
        target = target.cuda()
        # predict
        image_features = model.encode_image(images, alpha)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        score = 100. * image_features @ zeroshot_weights

        pred = score.topk(1, dim=1)[1].squeeze(dim=1)
        pred_5 = score.topk(5, dim=1)[1].squeeze(dim=1)

        for i in range(target.shape[0]):
            if target[i].item() not in temp_corr_dict:
| 47 | 
         
            +
                            temp_corr_dict[target[i].item()] = [0, 0, 0]
         
     | 
| 48 | 
         
            +
                        temp_corr_dict[target[i].item()][0] += 1
         
     | 
| 49 | 
         
            +
                        if target[i].item() == pred[i].item():
         
     | 
| 50 | 
         
            +
                            temp_corr_dict[target[i].item()][1] += 1
         
     | 
| 51 | 
         
            +
                        if target[i].item() in pred_5[i].tolist():
         
     | 
| 52 | 
         
            +
                            temp_corr_dict[target[i].item()][2] += 1
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
            acc1 = 0.0
         
     | 
| 55 | 
         
            +
            acc5 = 0.0
         
     | 
| 56 | 
         
            +
            num_class = 0
         
     | 
| 57 | 
         
            +
            for v in temp_corr_dict.values():
         
     | 
| 58 | 
         
            +
                if v[0] == 0: continue
         
     | 
| 59 | 
         
            +
                acc1 += v[1] / v[0]
         
     | 
| 60 | 
         
            +
                acc5 += v[2] / v[0]
         
     | 
| 61 | 
         
            +
                num_class += 1
         
     | 
| 62 | 
         
            +
            acc1 = acc1 / num_class * 100
         
     | 
| 63 | 
         
            +
            acc5 = acc5 / num_class * 100
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
            print(f"Top-1 accuracy: {acc1:.2f}")
         
     | 
| 66 | 
         
            +
            print(f"Top-5 accuracy: {acc5:.2f}")
         
     | 
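Note that the script above reports class-balanced (mean per-class) accuracy: every class contributes equally, no matter how many samples it has. A minimal sketch of that averaging on hypothetical counts, using the same `[num_samples, top1_correct, top5_correct]` layout as `temp_corr_dict`:

```python
# hypothetical per-class counters in the same format as temp_corr_dict:
# class_id -> [num_samples, num_top1_correct, num_top5_correct]
temp_corr_dict = {
    0: [10, 9, 10],    # 90% top-1 on 10 samples
    1: [100, 50, 80],  # 50% top-1 on 100 samples
}

# class-balanced: (0.90 + 0.50) / 2 * 100 = 70.0,
# whereas sample-weighted accuracy would be (9 + 50) / 110 * 100 = 53.6
acc1 = sum(v[1] / v[0] for v in temp_corr_dict.values()) / len(temp_corr_dict) * 100
print(f"{acc1:.1f}")  # 70.0
```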
    	
AlphaCLIP/eval/rec_zs_test/LICENSE.md
ADDED
@@ -0,0 +1,201 @@
(standard Apache License, Version 2.0 text, reproduced verbatim)
    	
AlphaCLIP/eval/rec_zs_test/README.md
ADDED
@@ -0,0 +1,74 @@
## Zero-Shot Referring Expression Comprehension on RefCOCO

**Preparing Data**

1. Download the [images for RefCOCO/g/+](http://images.cocodataset.org/zips/train2014.zip) and put the extracted dataset (train2014) under eval/rec_zs_test/data/.

2. Download the preprocessed data files via `gsutil cp gs://reclip-sanjays/reclip_data.tar.gz .`, then `cd rec_zs_test` and extract them with `tar -xvzf reclip_data.tar.gz`.

**Preparing Models**

3. Download [SAM](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) (vit-h) and an [Alpha-CLIP](https://github.com/SunzeY/AlphaCLIP/blob/main/model-zoo.md) checkpoint, and put them in ./eval/rec_zs_test/ckpt.

```
├── eval
│   ├── rec_zs_test
│   │   ├── data
│   │       └── train2014
│   │   ├── reclip_data
│   │       └── refcoco_val.jsonl
│   │       └── refcoco_dets_dict.json
│   │           ...
│   │   ├── ckpt
│   │       └── sam_vit_h_4b8939.pth
│   │       └── grit1m
│   │           └── clip_b16_grit+mim_fultune_4xe.pth
│   │           └── clip_l14_grit+mim_fultune_6xe.pth
│   │   ├── methods
│   │   ├── cache
│   │   ├── output
│   │   ├── main.py
│   │   ├── executor.py
│   │   ├── run.sh
│   │   ├── ...
```

4. Run the test script:

```
cd eval/rec_zs_test
bash run.sh
```

or

```
python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_representation_method full,blur --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --detector_file reclip_data/refcoco+_dets_dict.json --cache_path ./cache
```

(We recommend setting `--cache_path` so that the SAM masks for each image are generated only once and reused.)

For multi-GPU testing, try:

```
bash run_multi_gpus.sh
python cal_acc.py refcoco_val
```


**Acknowledgement**

Our evaluation builds on the excellent [ReCLIP](https://github.com/allenai/reclip/tree/main) codebase: we simply replace CLIP with Alpha-CLIP and skip the image-cropping operation.


**Experiment Results**

| Method         | RefCOCO |      |      | RefCOCO+ |      |      | RefCOCOg |      |
|----------------|---------|------|------|----------|------|------|----------|------|
|                | Val     | TestA| TestB| Val      | TestA| TestB| Val      | Test |
| CPT [67]       | 32.2    | 36.1 | 30.3 | 31.9     | 35.2 | 28.8 | 36.7     | 36.5 |
| ReCLIP [54]    | 45.8    | 46.1 | 47.1 | 47.9     | 50.1 | 45.1 | 59.3     | 59.0 |
| Red Circle [52]| 49.8    | 58.6 | 39.9 | 55.3     | 63.9 | 45.4 | 59.4     | 58.9 |
| Alpha-CLIP     | 55.7    | 61.1 | 50.3 | 55.6     | 62.7 | 46.4 | 61.2     | 62.0 |
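Since this pipeline swaps ReCLIP's crop-and-score step for Alpha-CLIP, the core operation is: turn each candidate region (a SAM mask or a detector box) into an alpha map and rank regions by similarity to the referring expression. The sketch below illustrates that idea using only `alpha_clip` calls that appear elsewhere in this repository; the `device` argument, the alpha-channel transform values, and the helper name `score_regions` are illustrative assumptions, not the actual `executor.py` implementation.

```python
import numpy as np
import torch
import alpha_clip
from PIL import Image
from torchvision import transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
# checkpoint path follows the ckpt/ layout shown above
model, preprocess = alpha_clip.load(
    "ViT-L/14",
    alpha_vision_ckpt_pth="ckpt/grit1m/clip_l14_grit+mim_fultune_6xe.pth",
    device=device,
)

# assumed alpha-channel transform (normalization values are illustrative)
mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(0.5, 0.26),
])

def score_regions(image: Image.Image, region_masks, expression: str) -> int:
    """Return the index of the region mask that best matches the expression."""
    text = alpha_clip.tokenize([expression]).to(device)
    with torch.no_grad():
        text_feat = model.encode_text(text)
        text_feat /= text_feat.norm(dim=-1, keepdim=True)
        scores = []
        for mask in region_masks:  # each mask: HxW binary numpy array (e.g. from SAM)
            img = preprocess(image).unsqueeze(0).to(device)
            alpha = mask_transform((mask * 255).astype(np.uint8)).unsqueeze(0).to(device)
            img_feat = model.encode_image(img, alpha)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)
            scores.append((img_feat @ text_feat.T).item())
    return int(np.argmax(scores))
```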
    	
AlphaCLIP/eval/rec_zs_test/cache/.gitkeep
ADDED
File without changes
    	
AlphaCLIP/eval/rec_zs_test/cal_acc.py
ADDED
@@ -0,0 +1,21 @@
import json
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('name', type=str, default='refcoco_val')

args = parser.parse_args()

name = args.name
print(name)
count = 0
all_count = 0
# aggregate the per-shard results (one JSON file per GPU shard, see run_multi_gpus.sh)
for i in range(8):
    pth = f'output/{name}_count_{i}.json'
    acc = json.load(open(pth, 'r'))  # each file holds a single "<num_correct> <num_total>" string
    a_list = acc.split()
    a, b = a_list[0], a_list[1]
    count += int(a)
    all_count += int(b)

print(float(count) / float(all_count))
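For reference, each shard file that `cal_acc.py` consumes is just a JSON-encoded string of two integers. A hypothetical sketch of writing one such file (the real writer lives in the evaluation entry point, which is not shown here, and the numbers are made up):

```python
import json

# hypothetical shard result: this GPU shard got 1234 of 2573 expressions right
num_correct, num_total = 1234, 2573
with open("output/refcoco_val_count_0.json", "w") as f:
    json.dump(f"{num_correct} {num_total}", f)  # cal_acc.py re-parses this "correct total" string
```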
    	
AlphaCLIP/eval/rec_zs_test/ckpt/.gitkeep
ADDED
File without changes

AlphaCLIP/eval/rec_zs_test/data/.gitkeep
ADDED
File without changes
    	
        AlphaCLIP/eval/rec_zs_test/entity_extraction.py
    ADDED
    
    | 
         @@ -0,0 +1,142 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from typing import Dict, Any, Callable, List, Tuple, NamedTuple, Text, Optional
         
     | 
| 2 | 
         
            +
            import numpy as np
         
     | 
| 3 | 
         
            +
            from spacy.tokens.token import Token
         
     | 
| 4 | 
         
            +
            from spacy.tokens.span import Span
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            from lattice import Product as L
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            from heuristics import Heuristics
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            Rel = Tuple[List[Token], "Entity"]
         
     | 
| 11 | 
         
            +
            Sup = List[Token]
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            DEFAULT_HEURISTICS = Heuristics()
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            def find_superlatives(tokens, heuristics) -> List[Sup]:
         
     | 
| 17 | 
         
            +
                """Modify and return a list of superlative tokens."""
         
     | 
| 18 | 
         
            +
                for heuristic in heuristics.superlatives:
         
     | 
| 19 | 
         
            +
                    if any(tok.text in heuristic.keywords for tok in tokens):
         
     | 
| 20 | 
         
            +
                        tokens.sort(key=lambda tok: tok.i)
         
     | 
| 21 | 
         
            +
                        return [tokens]
         
     | 
| 22 | 
         
            +
                return []
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            def expand_chunks(doc, chunks):
         
     | 
| 25 | 
         
            +
                expanded = {}
         
     | 
| 26 | 
         
            +
                for key in chunks:
         
     | 
| 27 | 
         
            +
                    chunk = chunks[key]
         
     | 
| 28 | 
         
            +
                    start = chunk.start
         
     | 
| 29 | 
         
            +
                    end = chunk.end
         
     | 
| 30 | 
         
            +
                    for i in range(chunk.start-1, -1, -1):
         
     | 
| 31 | 
         
            +
                        if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)):
         
     | 
| 32 | 
         
            +
                            if not any(any(doc[i].is_ancestor(doc[j]) for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2):
         
     | 
| 33 | 
         
            +
                                start = i
         
     | 
| 34 | 
         
            +
                    for i in range(chunk.end, len(doc)):
         
     | 
| 35 | 
         
            +
                        if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)):
         
     | 
| 36 | 
         
            +
                            if not any(any(doc[i].is_ancestor(doc[j]) or i == j for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2):
         
     | 
| 37 | 
         
            +
                                end = i+1
         
     | 
| 38 | 
         
            +
                            else:
         
     | 
| 39 | 
         
            +
                                break
         
     | 
| 40 | 
         
            +
                    expanded[key] = Span(doc=doc, start=start, end=end)
         
     | 
| 41 | 
         
            +
                return expanded
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
            class Entity(NamedTuple):
         
     | 
| 44 | 
         
            +
                """Represents an entity with locative constraints extracted from the parse."""
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
                head: Span
         
     | 
| 47 | 
         
            +
                relations: List[Rel]
         
     | 
| 48 | 
         
            +
                superlatives: List[Sup]
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                @classmethod
         
     | 
| 51 | 
         
            +
                def extract(cls, head, chunks, heuristics: Optional[Heuristics] = None) -> "Entity":
         
     | 
| 52 | 
         
            +
                    """Extract entities from a spacy parse.
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
                    Jointly recursive with `_get_rel_sups`."""
         
     | 
| 55 | 
         
            +
                    if heuristics is None:
         
     | 
| 56 | 
         
            +
                        heuristics = DEFAULT_HEURISTICS
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
                    if head.i not in chunks:
         
     | 
| 59 | 
         
            +
                        # Handles predicative cases.
         
     | 
| 60 | 
         
            +
                        children = list(head.children)
         
     | 
| 61 | 
         
            +
                        if children and children[0].i in chunks:
         
     | 
| 62 | 
         
            +
                            head = children[0]
         
     | 
| 63 | 
         
            +
                            # TODO: Also extract predicative relations.
         
     | 
| 64 | 
         
            +
                        else:
         
     | 
| 65 | 
         
            +
                            return None
         
     | 
| 66 | 
         
            +
                    hchunk = chunks[head.i]
         
     | 
| 67 | 
         
            +
                    rels, sups = cls._get_rel_sups(head, head, [], chunks, heuristics)
         
     | 
| 68 | 
         
            +
                    return cls(hchunk, rels, sups)
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
                @classmethod
         
     | 
| 71 | 
         
            +
                def _get_rel_sups(cls, token, head, tokens, chunks, heuristics) -> Tuple[List[Rel], List[Sup]]:
         
     | 
| 72 | 
         
            +
                    hchunk = chunks[head.i]
         
     | 
| 73 | 
         
            +
                    is_keyword = any(token.text in h.keywords for h in heuristics.relations)
         
     | 
| 74 | 
         
            +
                    is_keyword |= token.text in heuristics.null_keywords
         
     | 
| 75 | 
         
        # Found another entity head.
        if token.i in chunks and chunks[token.i] is not hchunk and not is_keyword:
            tchunk = chunks[token.i]
            tokens.sort(key=lambda tok: tok.i)
            subhead = cls.extract(token, chunks, heuristics)
            return [(tokens, subhead)], []

        # End of a chain of modifiers.
        n_children = len(list(token.children))
        if n_children == 0:
            return [], find_superlatives(tokens + [token], heuristics)

        relations = []
        superlatives = []
        is_keyword |= any(token.text in h.keywords for h in heuristics.superlatives)
        for child in token.children:
            if token.i in chunks and child.i in chunks and chunks[token.i] is chunks[child.i]:
                if not any(child.text in h.keywords for h in heuristics.superlatives):
                    if n_children == 1:
                        # Catches "the goat on the left"
                        sups = find_superlatives(tokens + [token], heuristics)
                        superlatives.extend(sups)
                continue
            new_tokens = tokens + [token] if token.i not in chunks or is_keyword else tokens
            subrel, subsup = cls._get_rel_sups(child, head, new_tokens, chunks, heuristics)
            relations.extend(subrel)
            superlatives.extend(subsup)
        return relations, superlatives

    def expand(self, span: Span = None):
        tokens = [token for token in self.head]
        if span is None:
            span = [None]
        for target_token in span:
            include = False
            stack = [token for token in self.head]
            while len(stack) > 0:
                token = stack.pop()
                if token == target_token:
                    token2 = target_token.head
                    while token2.head != token2:
                        tokens.append(token2)
                        token2 = token2.head
                    tokens.append(token2)
                    stack = []
                    include = True
                if target_token is None or include:
                    tokens.append(token)
                for child in token.children:
                    stack.append(child)
        tokens = list(set(tokens))
        tokens = sorted(tokens, key=lambda x: x.i)
        return ' '.join([token.text for token in tokens])

    def __eq__(self, other: "Entity") -> bool:
        if self.text != other.text:
            return False
        if self.relations != other.relations:
            return False
        if self.superlatives != other.superlatives:
            return False
        return True

    @property
    def text(self) -> Text:
        """Get the text predicate associated with this entity."""
        return self.head.text
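The expand() method above rebuilds a surface string by walking from the chunk's tokens up to the dependency root and out through their children. A rough standalone sketch of that subtree-expansion idea, assuming spaCy with the en_core_web_sm model is installed (none of the names below come from this repository):

# Hedged sketch only; "en_core_web_sm" and the example sentence are assumptions.
import spacy

nlp = spacy.load("en_core_web_sm")
doc = nlp("the goat on the left of the fence")

# Take a token's whole subtree plus the path up to the root, then restore
# surface order by token index -- similar in spirit to Entity.expand().
tok = next(t for t in doc if t.text == "goat")
tokens = {tok, *tok.subtree, *tok.ancestors}
print(" ".join(t.text for t in sorted(tokens, key=lambda t: t.i)))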
    	
AlphaCLIP/eval/rec_zs_test/executor.py ADDED
@@ -0,0 +1,401 @@
from typing import List, Dict, Union, Tuple

from PIL import Image, ImageDraw, ImageFilter, ImageOps, ImageEnhance
import spacy
import hashlib
import os

import torch
import torchvision
import torchvision.transforms as transforms
import clip
from transformers import BertTokenizer, RobertaTokenizerFast
import ruamel.yaml as yaml
import copy

from interpreter import Box

import pycocotools.mask as mask_utils
import alpha_clip
from segment_anything import sam_model_registry, SamPredictor
import numpy as np
import cv2
import matplotlib.pyplot as plt

import pickle

class Executor:
    def __init__(self, device: str = "cpu", box_representation_method: str = "crop", method_aggregator: str = "max", enlarge_boxes: int = 0, expand_position_embedding: bool = False, square_size: bool = False, blur_std_dev: int = 100, cache_path: str = None, input_file: str = None) -> None:
        IMPLEMENTED_METHODS = ["blur", "full", "gray"]
        if any(m not in IMPLEMENTED_METHODS for m in box_representation_method.split(",")):
            raise NotImplementedError
        IMPLEMENTED_AGGREGATORS = ["max", "sum"]
        if method_aggregator not in IMPLEMENTED_AGGREGATORS:
            raise NotImplementedError
        self.box_representation_method = box_representation_method
        self.method_aggregator = method_aggregator
        self.enlarge_boxes = enlarge_boxes
        self.device = device
        self.expand_position_embedding = expand_position_embedding
        self.square_size = square_size
        self.blur_std_dev = blur_std_dev
        self.cache_path = cache_path

    def preprocess_image(self, image: Image) -> List[torch.Tensor]:
        return [preprocess(image) for preprocess in self.preprocesses]

    def preprocess_mask(self, mask: Image) -> List[torch.Tensor]:
        preprocess = self.preprocesses[0]
        return preprocess.transforms[1](preprocess.transforms[0](mask))

    def preprocess_text(self, text: str) -> torch.Tensor:
        raise NotImplementedError

    def call_model(self, model: torch.nn.Module, images: torch.Tensor, text: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
        raise NotImplementedError

    def tensorize_inputs(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth: str = None) -> Tuple[List[torch.Tensor], torch.Tensor]:
        images = []
        for preprocess in self.preprocesses:
            images.append([])

        if 'aclip' in self.clip_type:
            self.all_masks = []
            read_save = False
            if self.mask_path is not None: # load mask if cached
                file_name = image_pth.split('/')[-1].split('.')[0]+'.pkl'
                if os.path.exists(os.path.join(self.mask_path, file_name)):
                    all_rles = pickle.load(open(os.path.join(self.mask_path, file_name),'rb'))
                    for rle in all_rles:
                        mask = np.array(mask_utils.decode(rle), dtype=bool)
                        self.all_masks.append(mask)
                    read_save = True
            if not read_save:
                # use SAM to generate masks
                self.predictor.set_image(np.array(image.convert('RGB')))
                all_rles = []
                for i in range(len(boxes)):
                    box = [
                        max(boxes[i].left-self.enlarge_boxes, 0),
                        max(boxes[i].top-self.enlarge_boxes, 0),
                        min(boxes[i].right+self.enlarge_boxes, image.width),
                        min(boxes[i].bottom+self.enlarge_boxes, image.height)
                    ] # box prompt
                    input_box = np.array(box)
                    masks, _, _ = self.predictor.predict(
                        point_coords=None,
                        point_labels=None,
                        box=input_box[None, :],
                        multimask_output=False,
                    )
                    self.all_masks.append(masks[0])
                    rle = mask_utils.encode(np.array(masks[0][:, :, None], order='F', dtype="uint8"))[0]
                    rle["counts"] = rle["counts"].decode("utf-8")
                    all_rles.append(rle)
                if self.mask_path is not None: # save mask
                    os.makedirs(self.mask_path, exist_ok=True)
                    pickle.dump(all_rles, open(os.path.join(self.mask_path, file_name),'wb'))

        if self.cache_path is None or any([not os.path.exists(os.path.join(self.cache_path, "refcoco_val", model_name, "image", image_name, method_name+".pt")) for model_name in self.model_names for method_name in self.box_representation_method.split(',')]):
            if "full" in self.box_representation_method: # original full image with alpha-map
                for i in range(len(boxes)):
                    image_i = image.copy()
                    preprocessed_images = self.preprocess_image(image_i)
                    for j, img in enumerate(preprocessed_images):
                        images[j].append(img.to(self.device))
            if "blur" in self.box_representation_method:
                for i in range(len(boxes)):
                    image_i = image.copy()

                    mask = Image.new('L', image_i.size, 0)
                    draw = ImageDraw.Draw(mask)
                    box = (
                        max(boxes[i].left-self.enlarge_boxes, 0),
                        max(boxes[i].top-self.enlarge_boxes, 0),
                        min(boxes[i].right+self.enlarge_boxes, image_i.width),
                        min(boxes[i].bottom+self.enlarge_boxes, image_i.height)
                    )
                    if 'aclip' in self.clip_type:
                        width, height = image.size
                        for y in range(height):
                            for x in range(width):
                                if self.all_masks[i][y][x] == 1:
                                    draw.point((x, y), fill=255)
                    else:
                        draw.rectangle([box[:2], box[2:]], fill=255)
                    blurred = image_i.filter(ImageFilter.GaussianBlur(self.blur_std_dev))
                    blurred.paste(image_i, mask=mask)
                    preprocessed_images = self.preprocess_image(blurred)

                    for j, img in enumerate(preprocessed_images):
                        images[j].append(img.to(self.device))
            if "gray" in self.box_representation_method:
                for i in range(len(boxes)):
                    image_i = image.copy()
                    mask_i = self.all_masks[i]
                    width, height = image.size

                    pixels = image_i.load()
                    for y in range(height):
                        for x in range(width):
                            if mask_i[y][x] == 0:
                                pixel_value = pixels[x, y]
                                gray_value = int(0.2989 * pixel_value[0] + 0.5870 * pixel_value[1] + 0.1140 * pixel_value[2])
                                pixels[x, y] = (gray_value, gray_value, gray_value)
                    preprocessed_images = self.preprocess_image(image_i)
                    for j, img in enumerate(preprocessed_images):
                        images[j].append(img.to(self.device))

            imgs = [torch.stack(image_list) for image_list in images]
        else:
            imgs = [[] for _ in self.models]
        text_tensor = self.preprocess_text(caption.lower()).to(self.device)
        return imgs, text_tensor

    @torch.no_grad()
    def __call__(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth=None) -> torch.Tensor:
        images, text_tensor = self.tensorize_inputs(caption, image, boxes, image_name, image_pth)
        all_logits_per_image = []
        all_logits_per_text = []
        box_representation_methods = self.box_representation_method.split(',')
        caption_hash = hashlib.md5(caption.encode('utf-8')).hexdigest()
        for model, images_t, model_name in zip(self.models, images, self.model_names):
            self.image_feat_path = ""
            if self.cache_path is not None:
                text_cache_path = os.path.join(self.cache_path, "refcoco_val", model_name, "text"+("_shade" if self.box_representation_method == "shade" else ""))
                image_feat_path = os.path.join(self.cache_path, "refcoco_val", model_name, "image", image_name)
                self.image_feat_path = image_feat_path
            image_features = None
            text_features = None
            if self.cache_path is not None and os.path.exists(os.path.join(self.cache_path, "refcoco_val", model_name)):
                if os.path.exists(os.path.join(text_cache_path, caption_hash+".pt")):
                    text_features = torch.load(os.path.join(text_cache_path, caption_hash+".pt"), map_location=self.device)
                if os.path.exists(image_feat_path):
                    if all([os.path.exists(os.path.join(image_feat_path, method_name+".pt")) for method_name in box_representation_methods]):
                        image_features = []
                        for method_name in box_representation_methods:
                            features = torch.load(os.path.join(image_feat_path, method_name+".pt"), map_location=self.device)
                            image_features.append(torch.stack([
                                features[(box.x, box.y, box.w, box.h)]
                                for box in boxes
                            ]))
                        image_features = torch.stack(image_features)
                        image_features = image_features.view(-1, image_features.shape[-1])
            logits_per_image, logits_per_text, image_features, text_features = self.call_model(model, images_t, text_tensor, image_features=image_features, text_features=text_features, boxes=boxes, image_pth=image_pth)
            all_logits_per_image.append(logits_per_image)
            all_logits_per_text.append(logits_per_text)
            if self.cache_path is not None and image_name is not None and image_features is not None:
                image_features = image_features.view(len(box_representation_methods), len(boxes), image_features.shape[-1])
                if not os.path.exists(image_feat_path):
                    os.makedirs(image_feat_path)
                for i in range(image_features.shape[0]):
                    method_name = box_representation_methods[i]
                    if not os.path.exists(os.path.join(image_feat_path, method_name+".pt")):
                        image_features_dict = {(box.x, box.y, box.w, box.h): image_features[i,j,:].cpu() for j, box in enumerate(boxes)}
                        torch.save(image_features_dict, os.path.join(image_feat_path, method_name+".pt"))
            if self.cache_path is not None and not os.path.exists(os.path.join(text_cache_path, caption_hash+".pt")) and text_features is not None:
                assert text_features.shape[0] == 1
                if not os.path.exists(text_cache_path):
                    os.makedirs(text_cache_path)
                torch.save(text_features.cpu(), os.path.join(text_cache_path, caption_hash+".pt"))

        all_logits_per_image = torch.stack(all_logits_per_image).sum(0)
        all_logits_per_text = torch.stack(all_logits_per_text).sum(0)
        if self.method_aggregator == "max":
            all_logits_per_text = all_logits_per_text.view(-1, len(boxes)).max(dim=0, keepdim=True)[0]
        elif self.method_aggregator == "sum":
            all_logits_per_text = all_logits_per_text.view(-1, len(boxes)).sum(dim=0, keepdim=True)
        return all_logits_per_text.view(-1)

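# Aside, not a line of executor.py: the aggregation at the end of __call__ above
# operates on logits flattened over (representation method, box) and collapses the
# method axis per box. The sizes below are made-up example values.
import torch
num_methods, num_boxes = 3, 2
example_logits = torch.randn(num_methods * num_boxes)
per_method = example_logits.view(-1, num_boxes)        # same reshape as .view(-1, len(boxes))
max_scores = per_method.max(dim=0, keepdim=True)[0]    # method_aggregator == "max"
sum_scores = per_method.sum(dim=0, keepdim=True)       # method_aggregator == "sum"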
| 210 | 
         
            +
            class ClipExecutor(Executor):
         
     | 
| 211 | 
         
            +
                def __init__(self, clip_model: str = "ViT-B/32", device: str = "cpu", box_representation_method: str = "crop", method_aggregator: str = "max", enlarge_boxes: int = 0, expand_position_embedding: bool = False, square_size: bool = False, blur_std_dev: int = 100, cache_path: str = None, input_file: str = None, clip_type: str=None) -> None:
         
     | 
| 212 | 
         
            +
                    super().__init__(device, box_representation_method, method_aggregator, enlarge_boxes, expand_position_embedding, square_size, blur_std_dev, cache_path)
         
     | 
| 213 | 
         
            +
                    self.clip_models = clip_model.split(",")
         
     | 
| 214 | 
         
            +
                    self.model_names = [model_name.replace("/", "_") for model_name in self.clip_models]
         
     | 
| 215 | 
         
            +
                    self.models = []
         
     | 
| 216 | 
         
            +
                    self.preprocesses = []
         
     | 
| 217 | 
         
            +
                    self.data_name = input_file.split('/')[-1].split('.')[0]
         
     | 
| 218 | 
         
            +
                    self.mask_path = None
         
     | 
| 219 | 
         
            +
                    self.clip_type = clip_type
         
     | 
| 220 | 
         
            +
                    if self.cache_path is not None:
         
     | 
| 221 | 
         
            +
                        self.mask_path = os.path.join(self.cache_path, "refcoco_val", 'det_masks')
         
     | 
| 222 | 
         
            +
                    sam_checkpoint = "./ckpt/sam_vit_h_4b8939.pth"
         
     | 
| 223 | 
         
            +
                    model_type = "vit_h"
         
     | 
| 224 | 
         
            +
                    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
         
     | 
| 225 | 
         
            +
                    sam.to(device=device)
         
     | 
| 226 | 
         
            +
                    self.predictor = SamPredictor(sam)
         
     | 
| 227 | 
         
            +
                    for model_name in self.clip_models:
         
     | 
| 228 | 
         
            +
                        if 'aclip' in self.clip_type:#using alpha-clip
         
     | 
| 229 | 
         
            +
                            self.mask_transform = transforms.Compose([
         
     | 
| 230 | 
         
            +
                                transforms.ToTensor(), 
         
     | 
| 231 | 
         
            +
                                transforms.Resize((224, 224)),
         
     | 
| 232 | 
         
            +
                                transforms.Normalize(0.5, 0.26)
         
     | 
| 233 | 
         
            +
                            ]) 
         
     | 
| 234 | 
         
            +
                            if model_name == 'ViT-B/16':
         
     | 
| 235 | 
         
            +
                                model, preprocess = alpha_clip.load("ViT-B/16", alpha_vision_ckpt_pth="./ckpt/grit1m/clip_b16_grit+mim_fultune_4xe.pth", device=device)
         
     | 
| 236 | 
         
            +
                            elif model_name == 'ViT-L/14':
         
     | 
| 237 | 
         
            +
                                model, preprocess = alpha_clip.load("ViT-L/14", alpha_vision_ckpt_pth="./ckpt/grit1m/clip_l14_grit+mim_fultune_6xe.pth", device=device) 
         
     | 
| 238 | 
         
            +
                           
         
     | 
| 239 | 
         
            +
                        else: model, preprocess = clip.load(model_name, device=device, jit=False)
         
     | 
| 240 | 
         
            +
                        self.models.append(model)
         
     | 
| 241 | 
         
            +
                        if self.square_size:
         
     | 
| 242 | 
         
            +
                            print("Square size!")
         
     | 
| 243 | 
         
            +
                            preprocess.transforms[0] = transforms.Resize((model.visual.input_resolution, model.visual.input_resolution), interpolation=transforms.InterpolationMode.BICUBIC)
         
     | 
| 244 | 
         
            +
                        self.preprocesses.append(preprocess)
         
     | 
| 245 | 
         
            +
                    self.models = torch.nn.ModuleList(self.models)
         
     | 
| 246 | 
         
            +
             
     | 
| 247 | 
         
            +
                def preprocess_text(self, text: str) -> torch.Tensor:
         
     | 
| 248 | 
         
            +
                    if "aclip" in self.box_representation_method:
         
     | 
| 249 | 
         
            +
                        return alpha_clip.tokenize([text.lower()])
         
     | 
| 250 | 
         
            +
                    if "shade" in self.box_representation_method:
         
     | 
| 251 | 
         
            +
                        return clip.tokenize([text.lower()+" is in red color."])
         
     | 
| 252 | 
         
            +
                    return clip.tokenize(["a photo of "+text.lower()])
         
     | 
| 253 | 
         
            +
             
     | 
| 254 | 
         
            +
                def call_model(self, model: torch.nn.Module, images: torch.Tensor, text: torch.Tensor, image_features: torch.Tensor = None, text_features: torch.Tensor = None, boxes=None, image_pth=None) -> torch.Tensor:
         
     | 
| 255 | 
         
            +
                    if image_features is None:
         
     | 
| 256 | 
         
            +
                        print('computing image features')
         
     | 
| 257 | 
         
            +
                        if 'aclip' not in self.clip_type:
         
     | 
| 258 | 
         
            +
                            image_features = model.encode_image(images)
         
     | 
| 259 | 
         
            +
                            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
         
     | 
| 260 | 
         
            +
                        else:
         
     | 
| 261 | 
         
            +
                            image_features = []
         
     | 
| 262 | 
         
            +
                            if 'full' in self.box_representation_method:
         
     | 
| 263 | 
         
            +
                                aclip_images = images[:len(boxes)]
         
     | 
| 264 | 
         
            +
                                alphas = []
         
     | 
| 265 | 
         
            +
             
     | 
| 266 | 
         
            +
                                if os.path.exists(os.path.join(self.image_feat_path, 'full.pt')):
         
     | 
| 267 | 
         
            +
                                    features = torch.load(os.path.join(self.image_feat_path, 'full.pt'), map_location=self.device)
         
     | 
| 268 | 
         
            +
                                    aclip_image_features = torch.stack([
         
     | 
| 269 | 
         
            +
                                        features[(box.x, box.y, box.w, box.h)]
         
     | 
| 270 | 
         
            +
                                        for box in boxes
         
     | 
| 271 | 
         
            +
                                    ])
         
     | 
| 272 | 
         
            +
                                else:
         
     | 
| 273 | 
         
            +
                                    for i in range(len(self.all_masks)):
         
     | 
| 274 | 
         
            +
                                        binary_mask = self.all_masks[i] 
         
     | 
| 275 | 
         
            +
                                        alpha = self.mask_transform((binary_mask * 255).astype(np.uint8)) 
         
     | 
| 276 | 
         
            +
                                        alpha = alpha.half().cuda().unsqueeze(dim=0)
         
     | 
| 277 | 
         
            +
                                        alphas.append(alpha)
         
     | 
| 278 | 
         
            +
                                    
         
     | 
| 279 | 
         
            +
                                    alphas = torch.cat(alphas, dim=0)
         
     | 
| 280 | 
         
            +
                                    aclip_images = aclip_images.half()
         
     | 
| 281 | 
         
            +
                                    aclip_image_features = model.visual(aclip_images, alphas) # using alpha channels
         
     | 
| 282 | 
         
            +
                                images = images[len(boxes):]
         
     | 
| 283 | 
         
            +
                                image_features.append(aclip_image_features)
         
     | 
| 284 | 
         
            +
             
     | 
| 285 | 
         
            +
                            if 'blur' in self.box_representation_method:
         
     | 
| 286 | 
         
            +
                                if os.path.exists(os.path.join(self.image_feat_path, 'blur.pt')):
         
     | 
| 287 | 
         
            +
                                    features = torch.load(os.path.join(self.image_feat_path, 'blur.pt'), map_location=self.device)
         
     | 
| 288 | 
         
            +
                                    ablur_images_features = torch.stack([
         
     | 
| 289 | 
         
            +
                                        features[(box.x, box.y, box.w, box.h)]
         
     | 
| 290 | 
         
            +
                                        for box in boxes
         
     | 
| 291 | 
         
            +
                                    ])
         
     | 
| 292 | 
         
            +
                                else:
         
     | 
| 293 | 
         
            +
                                    ablur_images = images[:len(boxes)]
         
     | 
| 294 | 
         
            +
                                    alphas = []
         
     | 
| 295 | 
         
            +
                                    for i in range(len(self.all_masks)):
         
     | 
| 296 | 
         
            +
                                        binary_mask = self.all_masks[i]
         
     | 
| 297 | 
         
            +
                                        alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
         
     | 
| 298 | 
         
            +
                                        alpha = alpha.half().cuda().unsqueeze(dim=0)
         
     | 
| 299 | 
         
            +
                                        alphas.append(alpha)
         
     | 
| 300 | 
         
            +
                                    alphas = torch.cat(alphas, dim=0)
         
     | 
| 301 | 
         
            +
                                    ablur_images = ablur_images.half()
         
     | 
| 302 | 
         
            +
                                    ablur_images_features = model.visual(ablur_images, alphas)
         
     | 
| 303 | 
         
            +
                                images = images[len(boxes):]
         
     | 
| 304 | 
         
            +
                                image_features.append(ablur_images_features)
         
     | 
| 305 | 
         
            +
             
     | 
| 306 | 
         
            +
                            if 'gray' in self.box_representation_method:
         
     | 
| 307 | 
         
            +
                                if os.path.exists(os.path.join(self.image_feat_path, 'gray.pt')):
         
     | 
| 308 | 
         
            +
                                    features = torch.load(os.path.join(self.image_feat_path, 'gray.pt'), map_location=self.device)
         
     | 
| 309 | 
         
            +
                                    gray_images_features = torch.stack([
         
     | 
| 310 | 
         
            +
                                        features[(box.x, box.y, box.w, box.h)]
         
     | 
| 311 | 
         
            +
                                        for box in boxes
         
     | 
| 312 | 
         
            +
                                    ])
         
     | 
| 313 | 
         
            +
                                else:
         
     | 
| 314 | 
         
            +
                                    gray_images = images[:len(boxes)]
         
     | 
| 315 | 
         
            +
                                    alphas = []
         
     | 
| 316 | 
         
            +
                                    for i in range(len(self.all_masks)):
         
     | 
| 317 | 
         
            +
                                        binary_mask = self.all_masks[i]
         
     | 
| 318 | 
         
            +
                                        alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
         
     | 
| 319 | 
         
            +
                                        alpha = alpha.half().cuda().unsqueeze(dim=0)
         
     | 
| 320 | 
         
            +
                                        alphas.append(alpha)
         
     | 
| 321 | 
         
            +
                                    alphas = torch.cat(alphas, dim=0)
         
     | 
| 322 | 
         
            +
                                    gray_images = gray_images.half()
         
     | 
| 323 | 
         
            +
                                    gray_images_features = model.visual(gray_images, alphas)
         
     | 
| 324 | 
         
            +
                                images = images[len(boxes):]
         
     | 
| 325 | 
         
            +
                                image_features.append(gray_images_features)
         
     | 
| 326 | 
         
            +
             
     | 
| 327 | 
         
            +
             
     | 
| 328 | 
         
            +
                            image_features = torch.cat(image_features, dim=0)
         
     | 
| 329 | 
         
            +
                            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
         
     | 
| 330 | 
         
            +
                            
         
     | 
| 331 | 
         
            +
                    if text_features is None:
         
     | 
| 332 | 
         
            +
                        print('computing text features')
         
     | 
| 333 | 
         
            +
                        text_features = model.encode_text(text)
         
     | 
| 334 | 
         
            +
                        # normalized features
         
     | 
| 335 | 
         
            +
                        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
         
     | 
| 336 | 
         
            +
             
     | 
| 337 | 
         
            +
                    # cosine similarity as logits
         
     | 
| 338 | 
         
            +
                    logit_scale = model.logit_scale.exp()
         
     | 
| 339 | 
         
            +
                    logits_per_image = logit_scale * image_features @ text_features.t()
         
     | 
| 340 | 
         
            +
                    logits_per_text = logits_per_image.t()
         
     | 
| 341 | 
         
            +
                    return logits_per_image, logits_per_text, image_features, text_features
         
     | 
| 342 | 
         
            +
             
     | 
| 343 | 
         
            +
                def __call__(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth=None) -> torch.Tensor:
         
     | 
| 344 | 
         
            +
                    if self.expand_position_embedding: 
         
     | 
| 345 | 
         
            +
                        original_preprocesses = self.preprocesses
         
     | 
| 346 | 
         
            +
                        new_preprocesses = []
         
     | 
| 347 | 
         
            +
                        original_position_embeddings = []
         
     | 
| 348 | 
         
            +
                        for model_name, model, preprocess in zip(self.clip_models, self.models, self.preprocesses):
         
     | 
| 349 | 
         
            +
                            if "RN" in model_name:
         
     | 
| 350 | 
         
            +
                                model_spatial_dim = int((model.visual.attnpool.positional_embedding.shape[0]-1)**0.5)
         
     | 
| 351 | 
         
            +
                                patch_size = model.visual.input_resolution // model_spatial_dim
         
     | 
| 352 | 
         
            +
                                original_positional_embedding = model.visual.attnpool.positional_embedding.clone()
         
     | 
| 353 | 
         
            +
                                model.visual.attnpool.positional_embedding = torch.nn.Parameter(torch.nn.functional.interpolate(
         
     | 
| 354 | 
         
            +
                                    model.visual.attnpool.positional_embedding[1:,:].permute(1, 0).view(1, -1, model_spatial_dim, model_spatial_dim),
         
     | 
| 355 | 
         
            +
                                    size=(image.height // patch_size, image.width // patch_size),
         
     | 
| 356 | 
         
            +
                                    mode='bicubic',
         
     | 
| 357 | 
         
            +
                                    align_corners=False
         
     | 
| 358 | 
         
            +
                                ).squeeze(0).permute(1, 2, 0).view(-1, original_positional_embedding.shape[-1]))
         
     | 
| 359 | 
         
+                    model.visual.attnpool.positional_embedding = torch.nn.Parameter(torch.cat((
+                        original_positional_embedding[:1,:],
+                        model.visual.attnpool.positional_embedding
+                    ), dim=0))
+                    transform = transforms.Compose([
+                        transforms.Resize(((image.height // patch_size)*patch_size, (image.width // patch_size)*patch_size), interpolation=Image.BICUBIC),
+                        lambda image: image.convert("RGB"),
+                        transforms.ToTensor(),
+                        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+                    ])
+                else:
+                    model_spatial_dim = int((model.visual.positional_embedding.shape[0]-1)**0.5)
+                    patch_size = model.visual.input_resolution // model_spatial_dim
+                    original_positional_embedding = model.visual.positional_embedding.clone()
+                    model.visual.positional_embedding = torch.nn.Parameter(torch.nn.functional.interpolate(
+                        model.visual.positional_embedding[1:,:].permute(1, 0).view(1, -1, model_spatial_dim, model_spatial_dim),
+                        size=(image.height // patch_size, image.width // patch_size),
+                        mode='bicubic',
+                        align_corners=False
+                    ).squeeze(0).permute(1, 2, 0).view(-1, original_positional_embedding.shape[-1]))
+                    model.visual.positional_embedding = torch.nn.Parameter(torch.cat((
+                        original_positional_embedding[:1,:],
+                        model.visual.positional_embedding
+                    ), dim=0))
+                    transform = transforms.Compose([
+                        transforms.Resize(((image.height // patch_size)*patch_size, (image.width // patch_size)*patch_size), interpolation=Image.BICUBIC),
+                        lambda image: image.convert("RGB"),
+                        transforms.ToTensor(),
+                        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+                    ])
+                new_preprocesses.append(transform)
+                original_position_embeddings.append(original_positional_embedding)
+            self.preprocesses = new_preprocesses
+        result = super().__call__(caption, image, boxes, image_name, image_pth)
+        if self.expand_position_embedding:
+            self.preprocesses = original_preprocesses
+            for model, model_name, pos_embedding in zip(self.models, self.clip_models, original_position_embeddings):
+                if "RN" in model_name:
+                    model.visual.attnpool.positional_embedding = torch.nn.Parameter(pos_embedding)
+                else:
+                    model.visual.positional_embedding = torch.nn.Parameter(pos_embedding)
+        return result
+
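The hunk above is the part of AlphaCLIP/eval/rec_zs_test/executor.py that lets CLIP accept non-square inputs: it bicubically interpolates the patch position embeddings to the new spatial grid, re-attaches the class-token embedding, builds a matching Resize transform, and restores the original embeddings after the call. A minimal standalone sketch of the same interpolation step; the function name and argument shapes are illustrative, not part of the commit:

import torch

def expand_positional_embedding(pos_embed: torch.Tensor, old_grid: int, new_hw) -> torch.Tensor:
    # pos_embed: (1 + old_grid**2, dim) -- class token followed by patch tokens.
    cls_tok, patch_tok = pos_embed[:1], pos_embed[1:]
    dim = pos_embed.shape[-1]
    patch_tok = patch_tok.permute(1, 0).view(1, dim, old_grid, old_grid)
    patch_tok = torch.nn.functional.interpolate(patch_tok, size=new_hw, mode='bicubic', align_corners=False)
    patch_tok = patch_tok.squeeze(0).permute(1, 2, 0).reshape(-1, dim)
    return torch.cat((cls_tok, patch_tok), dim=0)  # (1 + new_h*new_w, dim)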
    	
AlphaCLIP/eval/rec_zs_test/generic_clip_pairs.py ADDED
@@ -0,0 +1,107 @@
+import os
+import clip
+import json
+import argparse
+import ruamel.yaml as yaml
+
+from PIL import Image
+import torch
+import torchvision.transforms as transforms
+from tqdm import tqdm
+
+from albef.utils import *
+from executor import AlbefExecutor
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--input_path", type=str, help="Path to input JSON file")
+parser.add_argument("--image_root", type=str, help="Path to directory containing images")
+parser.add_argument("--albef_path", type=str, default=None, help="Path to ALBEF model/config/etc. if the goal is to use ALBEF")
+parser.add_argument("--albef_itc", action="store_true", help="Use ITC output of ALBEF")
+parser.add_argument("--clip_model", type=str, help="CLIP model to use")
+parser.add_argument("--gpu", type=int, default=-1, help="Which gpu to use")
+parser.add_argument("--batch_size", type=int, default=32, help="Batch size for running CLIP")
+
+args = parser.parse_args()
+
+if args.albef_path is not None:
+    executor = AlbefExecutor(checkpoint_path = os.path.join(args.albef_path, "checkpoint.pth"), config_path = os.path.join(args.albef_path, "config.yaml"), device = "cpu" if args.gpu < 0 else "cuda:"+str(args.gpu))
+    model = executor.models[0]
+    preprocess = executor.preprocesses[0]
+    model = model.eval()
+else:
+    model, preprocess = clip.load(args.clip_model, jit=False, device="cuda:"+str(args.gpu))
+    preprocess.transforms[0] = transforms.Resize((model.visual.input_resolution, model.visual.input_resolution), transforms.InterpolationMode.BICUBIC)
+    model = model.eval()
+input_file = open(args.input_path)
+data = json.load(input_file)
+input_file.close()
+correct = 0
+for i in tqdm(range(0, len(data), args.batch_size)):
+    batch_images = []
+    batch_text = []
+    for datum in data[i:min(i+args.batch_size, len(data))]:
+        img = Image.open(os.path.join(args.image_root, datum["image_filename"])).convert('RGB')
+        batch_images.append(preprocess(img))
+        if "text2" in datum:
+            if args.albef_path is None:
+                datum["text1"] = "a photo of "+datum["text1"]
+                datum["text2"] = "a photo of "+datum["text2"]
+            batch_text.append(datum["text1"])
+            batch_text.append(datum["text2"])
+        else:
+            img2 = Image.open(os.path.join(args.image_root, datum["image_filename2"])).convert('RGB')
+            batch_images.append(preprocess(img2))
+            batch_text.append(datum["text1"])
+    batch_images = torch.stack(batch_images).to("cuda:"+str(args.gpu))
+    if args.albef_path is None:
+        batch_text = clip.tokenize(batch_text).to("cuda:"+str(args.gpu))
+    else:
+        modified_text = [pre_caption(txt, executor.max_words) for txt in batch_text]
+        batch_text = executor.tokenizer(modified_text, padding='longest', return_tensors="pt")
+        for key in batch_text:
+            batch_text[key] = batch_text[key].to(batch_images.device)
+
+    with torch.no_grad():
+        if args.albef_path is None:
+            logits_per_image, logits_per_text = model(batch_images, batch_text)
+        else:
+            if not args.albef_itc:
+                if batch_images.shape[0]*2 == batch_text.input_ids.shape[0]:
+                    batch_images = batch_images.unsqueeze(1).repeat(1, 2, 1, 1, 1).view(batch_images.shape[0]*2, batch_images.shape[1], batch_images.shape[2], batch_images.shape[3])
+                else:
+                    assert batch_images.shape[0] == 2*batch_text.input_ids.shape[0]
+                    batch_text.input_ids = batch_text.input_ids.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1)
+                    batch_text.attention_mask = batch_text.attention_mask.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1)
+                image_embeds = model.visual_encoder(batch_images)
+                image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(batch_images.device)
+                output = model.text_encoder(
+                    batch_text.input_ids,
+                    attention_mask = batch_text.attention_mask,
+                    encoder_hidden_states = image_embeds,
+                    encoder_attention_mask = image_atts,
+                    return_dict = True,
+                )
+                vl_embeddings = output.last_hidden_state[:,0,:]
+                vl_output = model.itm_head(vl_embeddings)
+                logits_per_image = vl_output[:,1:2].view(-1, 2)
+            else:
+                image_embeds = model.visual_encoder(batch_images)
+                image_feat = torch.nn.functional.normalize(model.vision_proj(image_embeds[:,0,:]),dim=-1)
+                text_output = model.text_encoder(batch_text.input_ids, attention_mask = batch_text.attention_mask,
+                                                 return_dict = True, mode = 'text')
+                text_embeds = text_output.last_hidden_state
+                text_feat = torch.nn.functional.normalize(model.text_proj(text_embeds[:,0,:]),dim=-1)
+                sim = image_feat@text_feat.t()/model.temp
+                logits_per_image = sim
+    if args.albef_path is None or args.albef_itc:
+        if logits_per_image.shape[0]*2 == logits_per_image.shape[1]:
+            for j in range(logits_per_image.shape[0]):
+                correct += 1 if logits_per_image[j,2*j].item() > logits_per_image[j,2*j+1].item() else 0
+        else:
+            assert logits_per_image.shape[0] == 2*logits_per_image.shape[1]
+            for j in range(logits_per_image.shape[1]):
+                correct += 1 if logits_per_image[2*j,j].item() > logits_per_image[2*j+1,j].item() else 0
+    else:
+        correct += (logits_per_image[:,0] > logits_per_image[:,1]).long().sum().item()
+
+print("Accuracy:", correct/len(data))
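generic_clip_pairs.py scores pairs with either CLIP or ALBEF (ITM or ITC head) and reports how often the first candidate wins. Each input record pairs one image with two captions ("text1"/"text2") or one caption with two images ("image_filename"/"image_filename2"). A hypothetical input file matching those keys; the file and image names are placeholders, not part of the commit:

import json

records = [
    {"image_filename": "cat.jpg", "text1": "a cat", "text2": "a dog"},
    {"image_filename": "cat.jpg", "image_filename2": "dog.jpg", "text1": "a cat"},
]
with open("pairs.json", "w") as f:  # pass this path via --input_path
    json.dump(records, f)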
    	
AlphaCLIP/eval/rec_zs_test/heuristics.py ADDED
@@ -0,0 +1,68 @@
+"""Heuristic rules used to extract and execute entity parses."""
+
+from typing import Callable, List, NamedTuple
+from argparse import Namespace
+import numpy as np
+
+
+class RelHeuristic(NamedTuple):
+    keywords: List[str]
+    callback: Callable[["Environment"], np.ndarray]
+
+
+class Heuristics:
+    """A class defining heuristics that can be enabled/disabled."""
+
+    RELATIONS = [
+        RelHeuristic(["left", "west"], lambda env: env.left_of()),
+        RelHeuristic(["right", "east"], lambda env: env.right_of()),
+        RelHeuristic(["above", "north", "top", "back", "behind"], lambda env: env.above()),
+        RelHeuristic(["below", "south", "under", "front"], lambda env: env.below()),
+        RelHeuristic(["bigger", "larger", "closer"], lambda env: env.bigger_than()),
+        RelHeuristic(["smaller", "tinier", "further"], lambda env: env.smaller_than()),
+        RelHeuristic(["inside", "within", "contained"], lambda env: env.within()),
+    ]
+
+    TERNARY_RELATIONS = [
+        RelHeuristic(["between"], lambda env: env.between()),
+    ]
+
+    SUPERLATIVES = [
+        RelHeuristic(["left", "west", "leftmost", "western"], lambda env: env.left_of()),
+        RelHeuristic(["right", "rightmost", "east", "eastern"], lambda env: env.right_of()),
+        RelHeuristic(["above", "north", "top"], lambda env: env.above()),
+        RelHeuristic(["below", "south", "underneath", "front"], lambda env: env.below()),
+        RelHeuristic(["bigger", "biggest", "larger", "largest", "closer", "closest"], lambda env: env.bigger_than()),
+        RelHeuristic(["smaller", "smallest", "tinier", "tiniest", "further", "furthest"], lambda env: env.smaller_than()),
+    ]
+    OPPOSITES = {0: 1, 1: 0, 2: 3, 3: 2, 4: 5, 5: 4}
+
+    NULL_KEYWORDS = ["part", "image", "side", "picture", "half", "region", "section"]
+
+    EMPTY = []
+
+    def __init__(self, args: Namespace = None):
+        self.enable_relations = not args or not args.no_rel
+        self.enable_superlatives = not args or not args.no_sup
+        self.enable_nulls = not args or not args.no_null
+        self.enable_ternary = not args or args.ternary
+
+    @property
+    def relations(self) -> List[RelHeuristic]:
+        return self.RELATIONS if self.enable_relations else self.EMPTY
+
+    @property
+    def ternary_relations(self) -> List[RelHeuristic]:
+        return self.TERNARY_RELATIONS if self.enable_ternary else self.EMPTY
+
+    @property
+    def superlatives(self) -> List[RelHeuristic]:
+        return self.SUPERLATIVES if self.enable_superlatives else self.EMPTY
+
+    @property
+    def opposites(self):
+        return self.OPPOSITES
+
+    @property
+    def null_keywords(self) -> List[str]:
+        return self.NULL_KEYWORDS if self.enable_nulls else self.EMPTY
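Heuristics groups keyword-triggered spatial rules: RELATIONS and SUPERLATIVES map trigger words to Environment predicates, OPPOSITES pairs the indices of inverse relations in RELATIONS (left/right, above/below, bigger/smaller), and NULL_KEYWORDS marks head nouns that carry no visual content. A minimal keyword-lookup sketch; the matcher and the example sentence are illustrative, not part of the commit:

from heuristics import Heuristics

heuristics = Heuristics()  # every rule family is enabled when no argparse flags are given

def match_relation(text: str):
    # Return the first relation whose trigger word occurs in the text, if any.
    for index, rule in enumerate(heuristics.relations):
        if any(keyword in text.split() for keyword in rule.keywords):
            return index, rule  # rule.callback(env) yields an n_boxes x n_boxes score tensor
    return None

print(match_relation("the mug to the left of the laptop"))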
    	
AlphaCLIP/eval/rec_zs_test/interpreter.py ADDED
@@ -0,0 +1,212 @@
+from typing import NamedTuple, List, Callable
+import sys
+import re
+import numpy as np
+import torch
+from numpy.linalg import norm
+from itertools import product, groupby
+from PIL import Image
+
+
+# Do two line segments intersect? Copied from
+# https://stackoverflow.com/questions/3838329/how-can-i-check-if-two-segments-intersect
+
+
+def ccw(A, B, C):
+    return (C.y - A.y) * (B.x - A.x) > (B.y - A.y) * (C.x - A.x)
+
+
+def intersect(A, B, C, D):
+    """Do line segments AB and CD intersect?"""
+    return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)
+
+
+class Box(NamedTuple):
+    x: int
+    y: int
+    w: int = 0
+    h: int = 0
+
+    @property
+    def left(self):
+        return self.x
+
+    @property
+    def right(self):
+        return self.x + self.w
+
+    @property
+    def top(self):
+        return self.y
+
+    @property
+    def bottom(self):
+        return self.y + self.h
+
+    @property
+    def center(self):
+        return Box(self.x + self.w // 2, self.y + self.h // 2)
+
+    def corners(self):
+        yield Box(self.x, self.y)
+        yield Box(self.x + self.w, self.y)
+        yield Box(self.x + self.w, self.y + self.h)
+        yield Box(self.x, self.y + self.h)
+
+    @property
+    def area(self):
+        return self.w * self.h
+
+    def intersect(self, other: "Box") -> "Box":
+        x1 = max(self.x, other.x)
+        x2 = max(x1, min(self.x+self.w, other.x+other.w))
+        y1 = max(self.y, other.y)
+        y2 = max(y1, min(self.y+self.h, other.y+other.h))
+        return Box(x=x1, y=y1, w=x2-x1, h=y2-y1)
+
+    def min_bounding(self, other: "Box") -> "Box":
+        corners = list(self.corners())
+        corners.extend(other.corners())
+        min_x = min_y = float("inf")
+        max_x = max_y = -float("inf")
+
+        for item in corners:
+            min_x = min(min_x, item.x)
+            min_y = min(min_y, item.y)
+            max_x = max(max_x, item.x)
+            max_y = max(max_y, item.y)
+
+        return Box(min_x, min_y, max_x - min_x, max_y - min_y)
+
+    def expand(self, growth: float = .1) -> "Box":
+        factor = 1 + growth
+        w = factor * self.w
+        h = factor * self.h
+        return Box(self.x - (w - self.w) / 2, self.y - (h - self.h) / 2, w, h)
+
+
+def iou(box1, box2):
+    x1 = max(box1.x, box2.x)
+    x2 = max(x1, min(box1.x+box1.w, box2.x+box2.w))
+    y1 = max(box1.y, box2.y)
+    y2 = max(y1, min(box1.y+box1.h, box2.y+box2.h))
+    intersection = Box(x=x1, y=y1, w=x2-x1, h=y2-y1)
+    intersection_area = intersection.area
+    union_area = box1.area+box2.area-intersection_area
+    return intersection_area / union_area
+
+
+def all_equal(iterable):
+    """Are all elements the same?"""
+    g = groupby(iterable)
+    return next(g, True) and not next(g, False)
+
+
+class spatial:
+    """A decorator that converts a predicate over boxes to a function that returns a tensor over all boxes."""
+
+    def __init__(self, arity: int = 2, enforce_antisymmetry: bool = False):
+        self.arity = arity
+        self.enforce_antisymmetry = enforce_antisymmetry  # Zero out any entries where two boxes are the same.
+
+    def __call__(self, predicate: Callable[[Box], float]) -> Callable[["Environment"], np.ndarray]:
+        def _rel(env):
+            n_boxes = len(env.boxes)
+            tensor = np.empty([n_boxes for _ in range(self.arity)])
+            enum_boxes = list(enumerate(env.boxes))
+            for pairs in product(*[enum_boxes for _ in range(self.arity)]):
+                indices, boxes = zip(*pairs)
+                if self.enforce_antisymmetry and len(set(indices)) < len(indices):
+                    tensor[indices] = 0.
+                else:
+                    tensor[indices] = predicate(*boxes)
+            return tensor
+        return _rel
+
+
+class Environment:
+    def __init__(self, image: Image, boxes: List[Box], executor: "Executor" = None, freeform_boxes: bool = False, image_name: str = None, image_pth: str=None):
+        self.image = image
+        self.boxes = boxes
+        self.executor = executor  # An object or callback that can query CLIP with captions/images.
+        self.freeform_boxes = freeform_boxes
+        self.image_name = image_name
+        self.image_pth=image_pth
+
+    def uniform(self) -> np.ndarray:
+        n_boxes = len(self.boxes)
+        return 1 / n_boxes * np.ones(n_boxes)
+
+    def filter(self,
+               caption: str,
+               temperature: float = 1.,
+               area_threshold: float = 0.0,
+               softmax: bool = False,
+               expand: float = None
+              ) -> np.ndarray:
+        """Return a new distribution reflecting the likelihood that `caption` describes the content of each box."""
+        area_filtered_dist = torch.from_numpy(self.filter_area(area_threshold)).to(self.executor.device)
+        candidate_indices = [i for i in range(len(self.boxes)) if float(area_filtered_dist[i]) > 0.0]
+        boxes = [self.boxes[i] for i in candidate_indices]
+        if len(boxes) == 0:
+            boxes = self.boxes
+            candidate_indices = list(range(len(boxes)))
+        if expand is not None:
+            boxes = [box.expand(expand) for box in boxes]
+        result_partial = self.executor(caption, self.image, boxes, image_name=self.image_name, image_pth=self.image_pth)
+        if self.freeform_boxes:
+            result_partial, boxes = result_partial
+            self.boxes = [Box(x=boxes[i,0].item(), y=boxes[i,1].item(), w=boxes[i,2].item()-boxes[i,0].item(), h=boxes[i,3].item()-boxes[i,1].item()) for i in range(boxes.shape[0])]
+            candidate_indices = list(range(len(self.boxes)))
+        result_partial = result_partial.float()
+        if not softmax:
+            result_partial = (result_partial-result_partial.mean()) / (result_partial.std() + 1e-9)
+            result_partial = (temperature * result_partial).sigmoid()
+            result = torch.zeros((len(self.boxes))).to(result_partial.device)
+            result[candidate_indices] = result_partial
+        else:
+            result = torch.zeros((len(self.boxes))).to(result_partial.device)
+            result[candidate_indices] = result_partial.softmax(dim=-1)  # softmax result
+        return result.cpu().numpy()
+
+    def filter_area(self, area_threshold: float) -> np.ndarray:
+        """Return a mask that zeroes out boxes whose area, as a fraction of the image, is below the threshold."""
+        image_area = self.image.width*self.image.height
+        return np.array([1 if self.boxes[i].area/image_area > area_threshold else 0 for i in range(len(self.boxes))])
+
+    @spatial()
+    def left_of(b1, b2):
+        return (b1.right+b1.left) / 2 < (b2.right+b2.left) / 2
+
+    @spatial()
+    def right_of(b1, b2):
+        return (b1.right+b1.left) / 2 > (b2.right+b2.left) / 2
+
+    @spatial()
+    def above(b1, b2):
+        return (b1.bottom+b1.top) < (b2.bottom+b2.top)
+
+    @spatial()
+    def below(b1, b2):
+        return (b1.bottom+b1.top) > (b2.bottom+b2.top)
+
+    @spatial()
+    def bigger_than(b1, b2):
+        return b1.area > b2.area
+
+    @spatial()
+    def smaller_than(b1, b2):
+        return b1.area < b2.area
+
+    @spatial(enforce_antisymmetry=False)
+    def within(box1, box2):
+        """Return percent of box1 inside box2."""
+        intersection = box1.intersect(box2)
+        return intersection.area / box1.area
+
+    @spatial(arity=3, enforce_antisymmetry=True)
+    def between(box1, box2, box3):
+        """How much of box1 lies in min bounding box over box2 and box3?"""
+        min_bounding = box2.min_bounding(box3)
+        intersect = box1.intersect(min_bounding)
+        return intersect.area / box1.area
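interpreter.py supplies the geometric vocabulary used by the parser: Box arithmetic (intersection, minimum bounding box, IoU), the spatial decorator that lifts a pairwise predicate into an n_boxes x n_boxes (or cubic, for ternary relations) numpy tensor, and Environment.filter, which asks the executor how well a caption matches each box. A small usage sketch of the purely geometric parts; the box coordinates are made up for illustration:

from interpreter import Box, Environment, iou

boxes = [Box(x=10, y=20, w=100, h=80), Box(x=200, y=40, w=50, h=50)]
env = Environment(image=None, boxes=boxes)  # no image or executor is needed for the spatial relations

left = env.left_of()   # left[i, j] is 1.0 when box i's center lies left of box j's center
print(left)
print(iou(boxes[0], boxes[1]))  # 0.0 -- the two boxes are disjoint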
    	
AlphaCLIP/eval/rec_zs_test/lattice.py ADDED
@@ -0,0 +1,70 @@
+"""Implement lattice interface."""
+
+from overrides import overrides
+import numpy as np
+from abc import ABCMeta, abstractmethod
+
+
+class Lattice(metaclass=ABCMeta):
+
+    """Abstract base class representing a complemented lattice."""
+
+    @classmethod
+    @abstractmethod
+    def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray:
+        return NotImplemented
+
+    @classmethod
+    @abstractmethod
+    def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray:
+        return NotImplemented
+
+    @classmethod
+    @abstractmethod
+    def join_reduce(cls, probs: np.ndarray) -> np.ndarray:
+        return NotImplemented
+
+    @classmethod
+    @abstractmethod
+    def meet_reduce(cls, probs: np.ndarray) -> np.ndarray:
+        return NotImplemented
+
+
+class Product(Lattice):
+    """Lattice where meet=prod and sum is defined accordingly.
+
+    Equivalent to assuming independence, more or less.
+    """
+
+    eps = 1e-9
+
+    @classmethod
+    @overrides
+    def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray:
+        return probs1 + probs2 - cls.meet(probs1, probs2)
+
+    @classmethod
+    @overrides
+    def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray:
+        return probs1 * probs2
+
+    @classmethod
+    @overrides
+    def join_reduce(cls, probs: np.ndarray) -> np.ndarray:
+        """Assumes disjoint events."""
+        # return cls.comp(cls.meet_reduce(cls.comp(probs)))
+        return np.sum(probs, axis=-1)
+
+    @classmethod
+    @overrides
+    def meet_reduce(cls, probs: np.ndarray) -> np.ndarray:
+        return np.prod(probs, axis=-1)
+
+    @classmethod
+    def comp(cls, probs):
+        return 1 - probs
+
+    @classmethod
+    def normalize(cls, probs):
+        """Normalize a distribution by dividing by the total mass."""
+        return probs / np.sum(probs + cls.eps, axis=-1)
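Product treats per-box scores as independent probabilities: meet is elementwise multiplication (logical "and"), join follows inclusion-exclusion (logical "or"), and the *_reduce variants collapse the last axis. A small worked example with made-up scores:

import numpy as np
from lattice import Product

head = np.array([0.9, 0.2, 0.1])      # e.g. P(box matches "mug")
relation = np.array([0.8, 0.5, 0.3])  # e.g. P(box is left of the laptop)

both = Product.meet(head, relation)    # [0.72, 0.10, 0.03]
either = Product.join(head, relation)  # [0.98, 0.60, 0.37]
print(Product.normalize(both))         # rescale so the scores sum to roughly 1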
    	
AlphaCLIP/eval/rec_zs_test/main.py ADDED
@@ -0,0 +1,200 @@
| 1 | 
         
            +
            from collections import defaultdict
         
     | 
| 2 | 
         
            +
            import json
         
     | 
| 3 | 
         
            +
            import argparse
         
     | 
| 4 | 
         
            +
            import os
         
     | 
| 5 | 
         
            +
            import random
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            import torch
         
     | 
| 8 | 
         
            +
            from PIL import Image
         
     | 
| 9 | 
         
            +
            from tqdm import tqdm
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            from interpreter import *
         
     | 
| 12 | 
         
            +
            from executor import *
         
     | 
| 13 | 
         
            +
            from methods import *
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            METHODS_MAP = {
         
     | 
| 16 | 
         
            +
                "baseline": Baseline,
         
     | 
| 17 | 
         
            +
                "random": Random,
         
     | 
| 18 | 
         
            +
                "parse": Parse,
         
     | 
| 19 | 
         
            +
            }
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 22 | 
         
            +
                parser = argparse.ArgumentParser()
         
     | 
| 23 | 
         
            +
                parser.add_argument("--input_file", type=str, help="input file with expressions and annotations in jsonlines format")
         
     | 
| 24 | 
         
            +
                parser.add_argument("--image_root", type=str, help="path to images (train2014 directory of COCO)")
         
     | 
| 25 | 
         
            +
                parser.add_argument("--clip_model", type=str, default="RN50x16,ViT-B/32", help="which clip model to use (should use RN50x4, ViT-B/32, or both separated by a comma")
         
     | 
| 26 | 
         
            +
                parser.add_argument("--clip_type", type=str, default="aclip", help="which clip model to use (should use RN50x4, ViT-B/32, or both separated by a comma")
         
     | 
| 27 | 
         
            +
                parser.add_argument("--albef_path", type=str, default=None, help="to use ALBEF (instead of CLIP), specify the path to the ALBEF checkpoint")
         
     | 
| 28 | 
         
            +
                parser.add_argument("--method", type=str, default="parse", help="method to solve expressions")
         
     | 
| 29 | 
         
            +
                parser.add_argument("--box_representation_method", type=str, default="crop,blur", help="method of representing boxes as individual images (crop, blur, or both separated by a comma)")
         
     | 
| 30 | 
         
            +
                parser.add_argument("--box_method_aggregator", type=str, default="sum", help="method of combining box representation scores")
         
     | 
| 31 | 
         
            +
                parser.add_argument("--box_area_threshold", type=float, default=0.0, help="minimum area (as a proportion of image area) for a box to be considered as the answer")
         
     | 
| 32 | 
         
            +
                parser.add_argument("--output_file", type=str, default=None, help="(optional) output path to save results")
         
     | 
| 33 | 
         
            +
                parser.add_argument("--detector_file", type=str, default=None, help="(optional) file containing object detections. if not provided, the gold object boxes will be used.")
         
     | 
| 34 | 
         
            +
                parser.add_argument("--mock", action="store_true", help="(optional) mock CLIP execution.")
         
     | 
| 35 | 
         
            +
                parser.add_argument("--device", type=int, default=0, help="CUDA device to use.")
         
     | 
| 36 | 
         
            +
                parser.add_argument("--shuffle_words", action="store_true", help="If true, shuffle words in the sentence")
         
     | 
| 37 | 
         
            +
                parser.add_argument("--gradcam_alpha", type=float, nargs='+', help="alpha value to use for gradcam method")
         
     | 
| 38 | 
         
            +
                parser.add_argument("--enlarge_boxes", type=float, default=0.0, help="(optional) whether to enlarge boxes when passing them to the model")
         
     | 
| 39 | 
         
            +
                parser.add_argument("--part", type=str, default=None, help="(optional) specify how many parts to divide the dataset into and which part to run in the format NUM_PARTS,PART_NUM")
         
     | 
| 40 | 
         
            +
                parser.add_argument("--batch_size", type=int, default=1, help="number of instances to process in one model call (only supported for baseline model)")
         
     | 
| 41 | 
         
            +
                parser.add_argument("--baseline_head", action="store_true", help="For baseline, controls whether model is called on both full expression and head noun chunk of expression")
         
     | 
| 42 | 
         
            +
                parser.add_argument("--mdetr", type=str, default=None, help="to use MDETR as the executor model, specify the name of the MDETR model")
         
     | 
| 43 | 
         
            +
                parser.add_argument("--albef_block_num", type=int, default=8, help="block num for ALBEF gradcam")
         
     | 
| 44 | 
         
            +
                parser.add_argument("--albef_mode", type=str, choices=["itm", "itc"], default="itm")
         
     | 
| 45 | 
         
            +
                parser.add_argument("--expand_position_embedding",action="store_true")
         
     | 
| 46 | 
         
            +
                parser.add_argument("--gradcam_background", action="store_true")
         
     | 
| 47 | 
         
            +
                parser.add_argument("--mdetr_given_bboxes", action="store_true")
         
     | 
| 48 | 
         
            +
                parser.add_argument("--mdetr_use_token_mapping", action="store_true")
         
     | 
| 49 | 
         
            +
                parser.add_argument("--non_square_size", action="store_true")
         
     | 
| 50 | 
         
            +
                parser.add_argument("--blur_std_dev", type=int, default=100, help="standard deviation of Gaussian blur")
         
     | 
| 51 | 
         
            +
                parser.add_argument("--gradcam_ensemble_before", action="store_true", help="Average gradcam maps of different models before summing over the maps")
         
     | 
| 52 | 
         
            +
                parser.add_argument("--cache_path", type=str, default=None, help="cache features")
         
     | 
| 53 | 
         
            +
                # Arguments related to Parse method.
         
     | 
| 54 | 
         
            +
                parser.add_argument("--no_rel", action="store_true", help="Disable relation extraction.")
         
     | 
| 55 | 
         
            +
                parser.add_argument("--no_sup", action="store_true", help="Disable superlative extraction.")
         
     | 
| 56 | 
         
            +
                parser.add_argument("--no_null", action="store_true", help="Disable null keyword heuristics.")
         
     | 
| 57 | 
         
            +
                parser.add_argument("--ternary", action="store_true", help="Disable ternary relation extraction.")
         
     | 
| 58 | 
         
            +
                parser.add_argument("--baseline_threshold", type=float, default=float("inf"), help="(Parse) Threshold to use relations/superlatives.")
         
     | 
| 59 | 
         
            +
                parser.add_argument("--temperature", type=float, default=1., help="(Parse) Sigmoid temperature.")
         
     | 
| 60 | 
         
            +
                parser.add_argument("--superlative_head_only", action="store_true", help="(Parse) Superlatives only quanntify head predicate.")
         
     | 
| 61 | 
         
            +
                parser.add_argument("--sigmoid", action="store_true", help="(Parse) Use sigmoid, not softmax.")
         
     | 
| 62 | 
         
            +
                parser.add_argument("--no_possessive", action="store_true", help="(Parse) Model extraneous relations as possessive relations.")
         
     | 
| 63 | 
         
            +
                parser.add_argument("--expand_chunks", action="store_true", help="(Parse) Expand noun chunks to include descendant tokens that aren't ancestors of tokens in other chunks")
         
     | 
| 64 | 
         
            +
                parser.add_argument("--parse_no_branch", action="store_true", help="(Parse) Only do the parsing procedure if some relation/superlative keyword is in the expression")
         
     | 
| 65 | 
         
            +
                parser.add_argument("--possessive_no_expand", action="store_true", help="(Parse) Expand ent2 in possessive case")
         
     | 
| 66 | 
         
            +
                args = parser.parse_args()
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
                with open(args.input_file) as f: 
         
     | 
| 69 | 
         
            +
                    lines = f.readlines()
         
     | 
| 70 | 
         
            +
                    data = [json.loads(line) for line in lines]
         
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
                device = f"cuda:{args.device}" if torch.cuda.is_available() and args.device >= 0 else "cpu"
         
     | 
| 73 | 
         
            +
                gradcam = args.method == "gradcam"
         
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
                executor = ClipExecutor(clip_model=args.clip_model, box_representation_method=args.box_representation_method, method_aggregator=args.box_method_aggregator, device=device, square_size=not args.non_square_size, expand_position_embedding=args.expand_position_embedding, blur_std_dev=args.blur_std_dev, cache_path=args.cache_path, input_file=args.input_file, clip_type=args.clip_type)
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
                method = METHODS_MAP[args.method](args)
         
     | 
| 78 | 
         
            +
                correct_count = 0
         
     | 
| 79 | 
         
            +
                total_count = 0
         
     | 
| 80 | 
         
            +
                if args.output_file:
         
     | 
| 81 | 
         
            +
                    output_file = open(args.output_file, "w")
         
     | 
| 82 | 
         
            +
                if args.detector_file:
         
     | 
| 83 | 
         
            +
                    detector_file = open(args.detector_file)
         
     | 
| 84 | 
         
            +
                    detections_list = json.load(detector_file)
         
     | 
| 85 | 
         
            +
                    if isinstance(detections_list, dict):
         
     | 
| 86 | 
         
            +
                        detections_map = {int(image_id): detections_list[image_id] for image_id in detections_list}
         
     | 
| 87 | 
         
            +
                    else:
         
     | 
| 88 | 
         
            +
                        detections_map = defaultdict(list)
         
     | 
| 89 | 
         
            +
                        for detection in detections_list:
         
     | 
| 90 | 
         
            +
                            detections_map[detection["image_id"]].append(detection["box"])
         
     | 
| 91 | 
         
            +
                
         
     | 
| 92 | 
         
            +
                part = 0
         
     | 
| 93 | 
         
            +
                if args.part is not None: # for multi-gpu test / part-data test
         
     | 
| 94 | 
         
            +
                    num_parts = int(args.part.split(",")[0])
         
     | 
| 95 | 
         
            +
                    part = int(args.part.split(",")[1])
         
     | 
| 96 | 
         
            +
                    data = data[int(len(data)*part/num_parts):int(len(data)*(part+1)/num_parts)]
         
     | 
| 97 | 
         
            +
             
     | 
| 98 | 
         
            +
                batch_count = 0
         
     | 
| 99 | 
         
            +
                batch_boxes = []
         
     | 
| 100 | 
         
            +
                batch_gold_boxes = []
         
     | 
| 101 | 
         
            +
                batch_gold_index = []
         
     | 
| 102 | 
         
            +
                batch_file_names = []
         
     | 
| 103 | 
         
            +
                batch_sentences = []
         
     | 
| 104 | 
         
            +
                for datum in tqdm(data):
         
     | 
| 105 | 
         
            +
                    if "coco" in datum["file_name"].lower():
         
     | 
| 106 | 
         
            +
                        file_name = "_".join(datum["file_name"].split("_")[:-1])+".jpg"
         
     | 
| 107 | 
         
            +
                    else:
         
     | 
| 108 | 
         
            +
                        file_name = datum["file_name"]
         
     | 
| 109 | 
         
            +
                    img_path = os.path.join(args.image_root, file_name)
         
     | 
| 110 | 
         
            +
                    img = Image.open(img_path).convert('RGB')
         
     | 
| 111 | 
         
            +
                    gold_boxes = [Box(x=ann["bbox"][0], y=ann["bbox"][1], w=ann["bbox"][2], h=ann["bbox"][3]) for ann in datum["anns"]]
         
     | 
| 112 | 
         
            +
                    if isinstance(datum["ann_id"], int) or isinstance(datum["ann_id"], str):
         
     | 
| 113 | 
         
            +
                        datum["ann_id"] = [datum["ann_id"]]
         
     | 
| 114 | 
         
            +
                    assert isinstance(datum["ann_id"], list)
         
     | 
| 115 | 
         
            +
                    gold_index = [i for i in range(len(datum["anns"])) if datum["anns"][i]["id"] in datum["ann_id"]] 
         
     | 
| 116 | 
         
            +
                    if args.detector_file:
         
     | 
| 117 | 
         
            +
                            boxes = [Box(x=box[0], y=box[1], w=box[2], h=box[3]) for box in detections_map[int(datum["image_id"])]]
         
     | 
| 118 | 
         
            +
                            if len(boxes) == 0:
         
     | 
| 119 | 
         
            +
                                boxes = [Box(x=0, y=0, w=img.width, h=img.height)]
         
     | 
| 120 | 
         
            +
                    else:
         
     | 
| 121 | 
         
            +
                        boxes = gold_boxes
         
     | 
| 122 | 
         
            +
                    for sentence in datum["sentences"]:
         
     | 
| 123 | 
         
            +
                        env = Environment(img, boxes, executor, (args.mdetr is not None and not args.mdetr_given_bboxes), str(datum["image_id"]), img_path) 
         
     | 
| 124 | 
         
            +
                        if args.shuffle_words:
         
     | 
| 125 | 
         
            +
                            words = sentence["raw"].lower().split()
         
     | 
| 126 | 
         
            +
                            random.shuffle(words)
         
     | 
| 127 | 
         
            +
                            result = method.execute(" ".join(words), env)
         
     | 
| 128 | 
         
            +
                        else:
         
     | 
| 129 | 
         
            +
                            result = method.execute(sentence["raw"].lower(), env)
         
     | 
| 130 | 
         
            +
                        boxes = env.boxes
         
     | 
| 131 | 
         
            +
                        print(sentence["raw"].lower())
         
     | 
| 132 | 
         
            +
                        correct = False
         
     | 
| 133 | 
         
            +
                        for g_index in gold_index:
         
     | 
| 134 | 
         
            +
                            if iou(boxes[result["pred"]], gold_boxes[g_index]) > 0.5:
         
     | 
| 135 | 
         
            +
                                correct = True
         
     | 
| 136 | 
         
            +
                                break
         
     | 
| 137 | 
         
            +
                        if correct:
         
     | 
| 138 | 
         
            +
                            result["correct"] = 1
         
     | 
| 139 | 
         
            +
                            correct_count += 1
         
     | 
| 140 | 
         
            +
                        else:
         
     | 
| 141 | 
         
            +
                            result["correct"] = 0
         
     | 
| 142 | 
         
            +
                        if args.detector_file:
         
     | 
| 143 | 
         
            +
                            argmax_ious = []
         
     | 
| 144 | 
         
            +
                            max_ious = []
         
     | 
| 145 | 
         
            +
                            for g_index in gold_index:
         
     | 
| 146 | 
         
            +
                                ious = [iou(box, gold_boxes[g_index]) for box in boxes]
         
     | 
| 147 | 
         
            +
                                argmax_iou = -1
         
     | 
| 148 | 
         
            +
                                max_iou = 0
         
     | 
| 149 | 
         
            +
                                if max(ious) >= 0.5:
         
     | 
| 150 | 
         
            +
                                    for index, value in enumerate(ious):
         
     | 
| 151 | 
         
            +
                                        if value > max_iou:
         
     | 
| 152 | 
         
            +
                                            max_iou = value
         
     | 
| 153 | 
         
            +
                                            argmax_iou = index
         
     | 
| 154 | 
         
            +
                                argmax_ious.append(argmax_iou)
         
     | 
| 155 | 
         
            +
                                max_ious.append(max_iou)
         
     | 
| 156 | 
         
            +
                            argmax_iou = -1
         
     | 
| 157 | 
         
            +
                            max_iou = 0
         
     | 
| 158 | 
         
            +
                            if max(max_ious) >= 0.5:
         
     | 
| 159 | 
         
            +
                                for index, value in zip(argmax_ious, max_ious):
         
     | 
| 160 | 
         
            +
                                    if value > max_iou:
         
     | 
| 161 | 
         
            +
                                        max_iou = value
         
     | 
| 162 | 
         
            +
                                        argmax_iou = index
         
     | 
| 163 | 
         
            +
                            result["gold_index"] = argmax_iou
         
     | 
| 164 | 
         
            +
                        else:
         
     | 
| 165 | 
         
            +
                            result["gold_index"] = gold_index
         
     | 
| 166 | 
         
            +
                        result["bboxes"] = [[box.left, box.top, box.right, box.bottom] for box in boxes]
         
     | 
| 167 | 
         
            +
                        result["file_name"] = file_name
         
     | 
| 168 | 
         
            +
                        result["probabilities"] = result["probs"]
         
     | 
| 169 | 
         
            +
                        result["text"] = sentence["raw"].lower()
         
     | 
| 170 | 
         
            +
                        if args.output_file:
         
     | 
| 171 | 
         
            +
                            # Serialize numpy arrays for JSON.
         
     | 
| 172 | 
         
            +
                            for key in result:
         
     | 
| 173 | 
         
            +
                                if isinstance(result[key], np.ndarray):
         
     | 
| 174 | 
         
            +
                                    result[key] = result[key].tolist()
         
     | 
| 175 | 
         
            +
                                if isinstance(result[key], np.int64):
         
     | 
| 176 | 
         
            +
                                    result[key] = result[key].item()
         
     | 
| 177 | 
         
            +
                            output_file.write(json.dumps(result)+"\n")
         
     | 
| 178 | 
         
            +
                        total_count += 1
         
     | 
| 179 | 
         
            +
                        print(f"est_acc: {100 * correct_count / total_count:.3f}")
         
     | 
| 180 | 
         
            +
             
     | 
| 181 | 
         
            +
                if args.output_file:
         
     | 
| 182 | 
         
            +
                    output_file.close()
         
     | 
| 183 | 
         
            +
                print(f"acc: {100 * correct_count / total_count:.3f}")
         
     | 
| 184 | 
         
            +
                acc = 100 * correct_count / total_count
         
     | 
| 185 | 
         
            +
                
         
     | 
| 186 | 
         
            +
                result = {}
         
     | 
| 187 | 
         
            +
                result['acc'] = acc
         
     | 
| 188 | 
         
            +
                json.dump(acc, open(os.path.join('./output', args.input_file.split('/')[-1].split('.')[0] + '_acc_' + str(part)+'.json'),'w'))
         
     | 
| 189 | 
         
            +
                json.dump(str(correct_count)+' '+str(total_count), open(os.path.join('./output', args.input_file.split('/')[-1].split('.')[0] + '_count_' + str(part)+'.json'),'w'))
         
     | 
| 190 | 
         
            +
                stats = method.get_stats()
         
     | 
| 191 | 
         
            +
                if stats:
         
     | 
| 192 | 
         
            +
                    pairs = sorted(list(stats.items()), key=lambda tup: tup[0])
         
     | 
| 193 | 
         
            +
                    for key, value in pairs:
         
     | 
| 194 | 
         
            +
                        result[key] = value
         
     | 
| 195 | 
         
            +
                        if isinstance(value, float):
         
     | 
| 196 | 
         
            +
                            print(f"{key}: {value:.5f}")
         
     | 
| 197 | 
         
            +
                        else:
         
     | 
| 198 | 
         
            +
                            print(f"{key}: {value}")
         
     | 
| 199 | 
         
            +
             
     | 
| 200 | 
         
            +
                json.dump(result, open(os.path.join('./output', args.input_file.split('/')[-1].split('.')[0] + '_' + str(part)+'.json'),'w'))
         
     | 
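For reference, the loop above only relies on a handful of fields per jsonlines record: `file_name`, `image_id`, `anns` (each with an `id` and a `bbox` in `[x, y, w, h]` format), `ann_id`, and `sentences` (each with a `raw` string). A hypothetical minimal record, just to illustrate the expected shape (real RefCOCO-style inputs typically carry more fields):

    import json

    record = {
        # Hypothetical name; for COCO images the trailing "_<sent_id>" segment is stripped by main.py.
        "file_name": "COCO_train2014_000000000123_456.jpg",
        "image_id": 123,
        "anns": [
            {"id": 1, "bbox": [10.0, 20.0, 100.0, 80.0]},
            {"id": 2, "bbox": [150.0, 30.0, 60.0, 90.0]},
        ],
        "ann_id": [1],  # gold annotation id(s); a single int or str is also accepted
        "sentences": [{"raw": "the dog on the left"}],
    }

    # Hypothetical file name for --input_file.
    with open("example.jsonl", "w") as f:
        f.write(json.dumps(record) + "\n")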
    	
        AlphaCLIP/eval/rec_zs_test/methods/__init__.py
    ADDED
    
@@ -0,0 +1,3 @@
from .baseline import Baseline
from .random_method import Random
from .parse import Parse
    	
        AlphaCLIP/eval/rec_zs_test/methods/baseline.py
    ADDED
    
@@ -0,0 +1,57 @@
"""A naive baseline method: just pass the full expression to CLIP."""

from overrides import overrides
from typing import Dict, Any, List
import numpy as np
import torch
import spacy
from argparse import Namespace

from .ref_method import RefMethod
from lattice import Product as L


class Baseline(RefMethod):
    """CLIP-only baseline where each box is evaluated with the full expression."""

    nlp = spacy.load('en_core_web_sm')

    def __init__(self, args: Namespace):
        self.args = args
        self.box_area_threshold = args.box_area_threshold
        self.batch_size = args.batch_size
        self.batch = []

    @overrides
    def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
        chunk_texts = self.get_chunk_texts(caption)
        probs = env.filter(caption, area_threshold=self.box_area_threshold, softmax=True)
        if self.args.baseline_head:
            probs2 = env.filter(chunk_texts[0], area_threshold=self.box_area_threshold, softmax=True)
            probs = L.meet(probs, probs2)
        pred = np.argmax(probs)
        return {
            "probs": probs,
            "pred": pred,
            "box": env.boxes[pred],
        }

    def get_chunk_texts(self, expression: str) -> List:
        doc = self.nlp(expression)
        head = None
        for token in doc:
            if token.head.i == token.i:
                head = token
                break
        head_chunk = None
        chunk_texts = []
        for chunk in doc.noun_chunks:
            if head.i >= chunk.start and head.i < chunk.end:
                head_chunk = chunk.text
            chunk_texts.append(chunk.text)
        if head_chunk is None:
            if len(list(doc.noun_chunks)) > 0:
                head_chunk = list(doc.noun_chunks)[0].text
            else:
                head_chunk = expression
        return [head_chunk] + [txt for txt in chunk_texts if txt != head_chunk]
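A rough illustration of what `get_chunk_texts` does (a sketch only; the exact parse and chunking depend on the spaCy model and version):

    import spacy

    nlp = spacy.load("en_core_web_sm")
    doc = nlp("the red ball on the wooden table")

    # The parse root is the token that is its own head (typically "ball" here).
    root = next(tok for tok in doc if tok.head.i == tok.i)
    print(root.text)

    # Noun chunks, e.g. ["the red ball", "the wooden table"]. get_chunk_texts
    # returns the chunk containing the root first, followed by the other chunks,
    # which is what --baseline_head scores in addition to the full expression.
    print([chunk.text for chunk in doc.noun_chunks])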
    	
        AlphaCLIP/eval/rec_zs_test/methods/parse.py
    ADDED
    
@@ -0,0 +1,239 @@
"""Use spatial relations extracted from the parses."""

from typing import Dict, Any, Callable, List, Tuple, NamedTuple
from numbers import Number
from collections import defaultdict
from overrides import overrides
import numpy as np
import spacy
from spacy.tokens.token import Token
from spacy.tokens.span import Span
from argparse import Namespace

from .ref_method import RefMethod
from lattice import Product as L
from heuristics import Heuristics
from entity_extraction import Entity, expand_chunks


def get_conjunct(ent, chunks, heuristics: Heuristics) -> Entity:
    """If an entity represents a conjunction of two entities, pull them apart."""
    head = ent.head.root  # Not ...root.head. Confusing names here.
    if not any(child.text == "and" for child in head.children):
        return None
    for child in head.children:
        if child.i in chunks and head.i is not child.i:
            return Entity.extract(child, chunks, heuristics)
    return None


class Parse(RefMethod):
    """A REF method that extracts and composes predicates, relations, and superlatives from a dependency parse.

    The process is as follows:
        1. Use spacy to parse the document.
        2. Extract a semantic entity tree from the parse.
        3. Execute the entity tree to yield a distribution over boxes."""

    nlp = spacy.load('en_core_web_sm')

    def __init__(self, args: Namespace = None):
        self.args = args
        self.box_area_threshold = args.box_area_threshold
        self.baseline_threshold = args.baseline_threshold
        self.temperature = args.temperature
        self.superlative_head_only = args.superlative_head_only
        self.expand_chunks = args.expand_chunks
        self.branch = not args.parse_no_branch
        self.possessive_expand = not args.possessive_no_expand

        # Lists of keyword heuristics to use.
        self.heuristics = Heuristics(args)

        # Metrics for debugging relation extraction behavior.
        self.counts = defaultdict(int)

    @overrides
    def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
        """Construct an `Entity` tree from the parse and execute it to yield a distribution over boxes."""
        # Start by using the full caption, as in Baseline.
        probs = env.filter(caption, area_threshold=self.box_area_threshold, softmax=True)
        ori_probs = probs

        # Extend the baseline using parse stuff.
        doc = self.nlp(caption)
        head = self.get_head(doc)
        chunks = self.get_chunks(doc)
        if self.expand_chunks:
            chunks = expand_chunks(doc, chunks)
        entity = Entity.extract(head, chunks, self.heuristics)

        # If no head noun is found, take the first one.
        if entity is None and len(list(doc.noun_chunks)) > 0:
            head = list(doc.noun_chunks)[0]
            entity = Entity.extract(head.root.head, chunks, self.heuristics)
            self.counts["n_0th_noun"] += 1

        # If we have found some head noun, filter based on it.
        if entity is not None and (any(any(token.text in h.keywords for h in self.heuristics.relations+self.heuristics.superlatives) for token in doc) or not self.branch):
            ent_probs, texts = self.execute_entity(entity, env, chunks)
            probs = L.meet(probs, ent_probs)
        else:
            texts = [caption]
            self.counts["n_full_expr"] += 1

        if len(ori_probs) == 1:
            probs = ori_probs

        self.counts["n_total"] += 1
        pred = np.argmax(probs)
        return {
            "probs": probs,
            "pred": pred,
            "box": env.boxes[pred],
            "texts": texts
        }

    def execute_entity(self,
                       ent: Entity,
                       env: "Environment",
                       chunks: Dict[int, Span],
                       root: bool = True,
                      ) -> np.ndarray:
        """Execute an `Entity` tree recursively, yielding a distribution over boxes."""
        self.counts["n_rec"] += 1
        probs = [1, 1]
        head_probs = probs

        # Only use relations if the head baseline isn't certain.
        if len(probs) == 1 or len(env.boxes) == 1:
            return probs, [ent.text]

        m1, m2 = probs[:2]  # probs[(-probs).argsort()[:2]]
        text = ent.text
        rel_probs = []
        if self.baseline_threshold == float("inf") or m1 < self.baseline_threshold * m2:
            self.counts["n_rec_rel"] += 1
            for tokens, ent2 in ent.relations:
                self.counts["n_rel"] += 1
                rel = None
                # Heuristically decide which spatial relation is represented.
                for heuristic in self.heuristics.relations:
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        rel = heuristic.callback(env)
                        self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
                        break
                # Filter and normalize by the spatial relation.
                if rel is not None:
                    probs2 = self.execute_entity(ent2, env, chunks, root=False)
                    events = L.meet(np.expand_dims(probs2, axis=0), rel)
                    new_probs = L.join_reduce(events)
                    rel_probs.append((ent2.text, new_probs, probs2))
                    continue

                # This case specifically handles "between", which takes two noun arguments.
                rel = None
                for heuristic in self.heuristics.ternary_relations:
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        rel = heuristic.callback(env)
                        self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
                        break
                if rel is not None:
                    ent3 = get_conjunct(ent2, chunks, self.heuristics)
                    if ent3 is not None:
                        probs2 = self.execute_entity(ent2, env, chunks, root=False)
                        probs2 = np.expand_dims(probs2, axis=[0, 2])
                        probs3 = self.execute_entity(ent3, env, chunks, root=False)
                        probs3 = np.expand_dims(probs3, axis=[0, 1])
                        events = L.meet(L.meet(probs2, probs3), rel)
                        new_probs = L.join_reduce(L.join_reduce(events))
                        probs = L.meet(probs, new_probs)
                    continue
                # Otherwise, treat the relation as a possessive relation.
                if not self.args.no_possessive:
                    if self.possessive_expand:
                        text = ent.expand(ent2.head)
                    else:
                        text += f' {" ".join(tok.text for tok in tokens)} {ent2.text}'
                    # poss_probs = self._filter(text, env, root=root, expand=.3)
            probs = self._filter(text, env, root=root)
            texts = [text]
            return_probs = [(probs.tolist(), probs.tolist())]
            for (ent2_text, new_probs, ent2_only_probs) in rel_probs:
                probs = L.meet(probs, new_probs)
                probs /= probs.sum()
                texts.append(ent2_text)
                return_probs.append((probs.tolist(), ent2_only_probs.tolist()))

        # Only use superlatives if thresholds work out.
        m1, m2 = probs[(-probs).argsort()[:2]]
        if m1 < self.baseline_threshold * m2:
            self.counts["n_rec_sup"] += 1
            for tokens in ent.superlatives:
                self.counts["n_sup"] += 1
                sup = None
                for heuristic_index, heuristic in enumerate(self.heuristics.superlatives):
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        texts.append('sup:'+' '.join([tok.text for tok in tokens if tok.text in heuristic.keywords]))
                        sup = heuristic.callback(env)
                        self.counts[f"n_sup_{heuristic.keywords[0]}"] += 1
                        break
                if sup is not None:
                    # Could use `probs` or `head_probs` here?
                    precond = head_probs if self.superlative_head_only else probs
                    probs = L.meet(np.expand_dims(precond, axis=1)*np.expand_dims(precond, axis=0), sup).sum(axis=1)
                    probs = probs / probs.sum()
                    return_probs.append((probs.tolist(), None))

        if root:
            assert len(texts) == len(return_probs)
            return probs, (texts, return_probs, tuple(str(chunk) for chunk in chunks.values()))
        return probs

    def get_head(self, doc) -> Token:
        """Return the token that is the head of the dependency parse."""
        for token in doc:
            if token.head.i == token.i:
                return token
        return None

    def get_chunks(self, doc) -> Dict[int, Any]:
        """Return a dictionary mapping sentence indices to their noun chunk."""
        chunks = {}
        for chunk in doc.noun_chunks:
            for idx in range(chunk.start, chunk.end):
                chunks[idx] = chunk
        return chunks

    @overrides
    def get_stats(self) -> Dict[str, Number]:
        """Summary statistics that have been tracked on this object."""
        stats = dict(self.counts)
        n_rel_caught = sum(v for k, v in stats.items() if k.startswith("n_rel_"))
         
     | 
| 213 | 
         
            +
                    n_sup_caught = sum(v for k, v in stats.items() if k.startswith("n_sup_"))
         
     | 
| 214 | 
         
            +
                    stats.update({
         
     | 
| 215 | 
         
            +
                        "p_rel_caught": n_rel_caught / (self.counts["n_rel"] + 1e-9),
         
     | 
| 216 | 
         
            +
                        "p_sup_caught": n_sup_caught / (self.counts["n_sup"] + 1e-9),
         
     | 
| 217 | 
         
            +
                        "p_rec_rel": self.counts["n_rec_rel"] / (self.counts["n_rec"] + 1e-9),
         
     | 
| 218 | 
         
            +
                        "p_rec_sup": self.counts["n_rec_sup"] / (self.counts["n_rec"] + 1e-9),
         
     | 
| 219 | 
         
            +
                        "p_0th_noun": self.counts["n_0th_noun"] / (self.counts["n_total"] + 1e-9),
         
     | 
| 220 | 
         
            +
                        "p_full_expr": self.counts["n_full_expr"] / (self.counts["n_total"] + 1e-9),
         
     | 
| 221 | 
         
            +
                        "avg_rec": self.counts["n_rec"] / self.counts["n_total"],
         
     | 
| 222 | 
         
            +
                    })
         
     | 
| 223 | 
         
            +
                    return stats
         
     | 
| 224 | 
         
            +
             
     | 
| 225 | 
         
            +
                def _filter(self,
         
     | 
| 226 | 
         
            +
                            caption: str,
         
     | 
| 227 | 
         
            +
                            env: "Environment",
         
     | 
| 228 | 
         
            +
                            root: bool = False,
         
     | 
| 229 | 
         
            +
                            expand: float = None,
         
     | 
| 230 | 
         
            +
                           ) -> np.ndarray:
         
     | 
| 231 | 
         
            +
                    """Wrap a filter call in a consistent way for all recursions."""
         
     | 
| 232 | 
         
            +
                    kwargs = {
         
     | 
| 233 | 
         
            +
                        "softmax": not self.args.sigmoid,
         
     | 
| 234 | 
         
            +
                        "temperature": self.args.temperature,
         
     | 
| 235 | 
         
            +
                    }
         
     | 
| 236 | 
         
            +
                    if root:
         
     | 
| 237 | 
         
            +
                        return env.filter(caption, area_threshold=self.box_area_threshold, **kwargs)
         
     | 
| 238 | 
         
            +
                    else:
         
     | 
| 239 | 
         
            +
                        return env.filter(caption, **kwargs)
         
     | 
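The relation branch above combines per-box scores through the lattice helpers imported as L from this repo's lattice.py. The snippet below is only a minimal NumPy sketch of that combination step, assuming L.meet behaves like an elementwise/outer product and L.join_reduce marginalizes one axis with renormalization; the actual lattice definitions may differ.

import numpy as np

def combine_relation(probs, probs2, probs3, rel):
    # probs, probs2, probs3: per-box score vectors; rel[i, j]: score that
    # box i (subject) and box j (object) satisfy the spatial relation.
    events = probs2[:, None] * probs3[None, :] * rel   # "meet" of the three events
    new_probs = events.sum(axis=1)                     # reduce over the object boxes
    new_probs = new_probs / (new_probs.sum() + 1e-9)
    combined = probs * new_probs                       # meet with the current beliefs
    return combined / (combined.sum() + 1e-9)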
    	
AlphaCLIP/eval/rec_zs_test/methods/random_method.py
ADDED
@@ -0,0 +1,30 @@
"""A naive baseline method: just pass the full expression to CLIP."""

from overrides import overrides
from typing import Dict, Any
import random
from argparse import Namespace

import numpy as np

from .ref_method import RefMethod


class Random(RefMethod):
    """CLIP-only baseline where each box is evaluated with the full expression."""

    def __init__(self, args: Namespace):
        self.box_area_threshold = args.box_area_threshold

    @overrides
    def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
        probs = env.filter_area(self.box_area_threshold)*env.uniform()
        random_ordering = list(range(len(env.boxes)))
        random.shuffle(random_ordering)
        random_ordering = np.array(random_ordering)
        pred = np.argmax(probs*random_ordering)
        return {
            "probs": probs.tolist(),
            "pred": int(pred),
            "text": caption.lower()
        }
    	
AlphaCLIP/eval/rec_zs_test/methods/ref_method.py
ADDED
@@ -0,0 +1,13 @@
"""Base class for a method for doing referring expressions."""

from typing import Dict, Any
from abc import ABCMeta, abstractmethod


class RefMethod(metaclass=ABCMeta):
    @abstractmethod
    def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
        return NotImplemented

    def get_stats(self) -> Dict[str, Any]:
        return {}
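New evaluation methods plug in by subclassing RefMethod and implementing execute. The fragment below is a hypothetical minimal subclass (not part of this commit) that scores every candidate box with the full caption; it assumes an Environment.filter call shaped like the one used by the parse method above.

from typing import Dict, Any
from .ref_method import RefMethod

class FullExpression(RefMethod):
    """Hypothetical example: rank boxes by CLIP score for the whole caption."""

    def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
        probs = env.filter(caption, softmax=True)   # assumed Environment API
        return {
            "probs": probs.tolist(),
            "pred": int(probs.argmax()),
            "text": caption.lower(),
        }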
    	
AlphaCLIP/eval/rec_zs_test/output/.gitkeep
ADDED
File without changes
    	
AlphaCLIP/eval/rec_zs_test/requirements.txt
ADDED
@@ -0,0 +1,53 @@
attrs==21.2.0
blis==0.7.4
catalogue==2.0.4
certifi==2021.5.30
chardet==4.0.0
click==7.1.2
cymem==2.0.5
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0-py3-none-any.whl
filelock==3.0.12
ftfy==6.0.3
huggingface-hub==0.0.12
idna==2.10
iniconfig==1.1.1
itsdangerous==2.0.1
joblib==1.0.1
MarkupSafe==2.0.1
murmurhash==1.0.5
numpy==1.21.0
overrides==6.1.0
packaging==21.0
pathy==0.6.0
Pillow==8.2.0
pluggy==0.13.1
preshed==3.0.5
py==1.10.0
pydantic==1.7.4
pyparsing==2.4.7
pytest==6.2.4
PyYAML==5.4.1
regex==2021.7.6
requests==2.25.1
ruamel.yaml==0.17.10
ruamel.yaml.clib==0.2.6
sacremoses==0.0.45
scipy==1.7.0
six==1.16.0
smart-open==5.1.0
spacy==3.0.6
spacy-legacy==3.0.7
srsly==2.4.1
thinc==8.0.7
timm==0.4.12
tokenizers==0.10.3
toml==0.10.2
tqdm==4.61.2
transformers==4.9.0
typer==0.3.2
typing-extensions==3.10.0.0
typing-utils==0.1.0
urllib3==1.26.6
wasabi==0.8.2
wcwidth==0.2.5
Werkzeug==2.0.1
    	
AlphaCLIP/eval/rec_zs_test/run.sh
ADDED
@@ -0,0 +1 @@
CUDA_VISIBLE_DEVICES=0 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_representation_method full,blur --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache
    	
AlphaCLIP/eval/rec_zs_test/run_multi_gpus.sh
ADDED
@@ -0,0 +1,15 @@
CUDA_VISIBLE_DEVICES=0 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,0" &

CUDA_VISIBLE_DEVICES=1 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,1" &

CUDA_VISIBLE_DEVICES=2 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,2" &

CUDA_VISIBLE_DEVICES=3 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,3" &

CUDA_VISIBLE_DEVICES=4 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,4" &

CUDA_VISIBLE_DEVICES=5 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,5" &

CUDA_VISIBLE_DEVICES=6 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,6" &

CUDA_VISIBLE_DEVICES=7 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,7"
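Each of the eight processes handles one shard of the input file via --part "8,i". The helper below is only a sketch of the sharding this flag suggests (main.py's actual splitting logic is not shown in this excerpt): split the examples into N contiguous chunks and keep chunk i.

def select_part(examples, part):
    # part is a string like "8,3": 8 shards in total, keep shard index 3.
    n_parts, index = (int(x) for x in part.split(","))
    chunk = (len(examples) + n_parts - 1) // n_parts
    return examples[index * chunk:(index + 1) * chunk]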
    	
AlphaCLIP/hubconf.py
ADDED
@@ -0,0 +1,42 @@
from alpha_clip.alpha_clip import tokenize as _tokenize, load as _load, available_models as _available_models
import re
import string

dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"]

# For compatibility (cannot include special characters in function name)
model_functions = {model: re.sub(f'[{string.punctuation}]', '_', model) for model in _available_models()}

def _create_hub_entrypoint(model):
    def entrypoint(**kwargs):
        return _load(model, **kwargs)

    entrypoint.__doc__ = f"""Loads the {model} CLIP model

        Parameters
        ----------
        device : Union[str, torch.device]
            The device to put the loaded model

        jit : bool
            Whether to load the optimized JIT model or more hackable non-JIT model (default).

        download_root: str
            path to download the model files; by default, it uses "~/.cache/clip"

        Returns
        -------
        model : torch.nn.Module
            The {model} CLIP model

        preprocess : Callable[[PIL.Image], torch.Tensor]
            A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
        """
    return entrypoint

def tokenize():
    return _tokenize

_entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()}

globals().update(_entrypoints)
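hubconf.py turns every model name from available_models() into a torch.hub entrypoint, with punctuation replaced by underscores (so a name like "ViT-B/16" becomes the entrypoint "ViT_B_16"). A hedged usage sketch follows; the repository path and model name are illustrative assumptions, not confirmed by this commit.

import torch

# Load one entrypoint created by hubconf.py; kwargs are forwarded to alpha_clip.load.
model, preprocess = torch.hub.load("SunzeY/AlphaCLIP", "ViT_B_16", device="cpu")
tokenize = torch.hub.load("SunzeY/AlphaCLIP", "tokenize")
tokens = tokenize(["a photo of a cat"])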
    	
AlphaCLIP/requirements.txt
ADDED
@@ -0,0 +1,5 @@
ftfy
regex
tqdm
torch
torchvision
    	
AlphaCLIP/setup.py
ADDED
@@ -0,0 +1,21 @@
import os

import pkg_resources
from setuptools import setup, find_packages

setup(
    name="alpha_clip",
    py_modules=["alpha_clip"],
    version="1.0",
    description="",
    author="OpenAI&ZeyiSun",
    packages=find_packages(exclude=["tests*"]),
    install_requires=[
        str(r)
        for r in pkg_resources.parse_requirements(
            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
        )
    ],
    include_package_data=True,
    extras_require={'dev': ['pytest']},
)
    	
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🏢
 colorFrom: green
 colorTo: red
 sdk: gradio
-sdk_version:
+sdk_version: 3.48.0
 app_file: app.py
 pinned: false
 license: mit
    	
app.py
ADDED
@@ -0,0 +1,113 @@
import gradio as gr
import sys
import torch
from omegaconf import OmegaConf
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline
from model.clip_away import CLIPAway
import cv2
import numpy as np
import argparse

# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config/inference_config.yaml", help="Path to the config file")
parser.add_argument("--share", action="store_true", help="Share the interface if provided")
args = parser.parse_args()

# Load configuration and models
config = OmegaConf.load(args.config)
sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", safety_checker=None, torch_dtype=torch.float32
)
clipaway = CLIPAway(
    sd_pipe=sd_pipeline,
    image_encoder_path=config.image_encoder_path,
    ip_ckpt=config.ip_adapter_ckpt_path,
    alpha_clip_path=config.alpha_clip_ckpt_pth,
    config=config,
    alpha_clip_id=config.alpha_clip_id,
    device=config.device,
    num_tokens=4
)

def dilate_mask(mask, kernel_size=5, iterations=5):
    mask = mask.convert("L")
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
    return Image.fromarray(mask)

def combine_masks(uploaded_mask, sketched_mask):
    if uploaded_mask is not None:
        return uploaded_mask
    elif sketched_mask is not None:
        return sketched_mask
    else:
        raise ValueError("Please provide a mask")

def remove_obj(image, uploaded_mask, seed):
    image_pil, sketched_mask = image["image"], image["mask"]
    mask = dilate_mask(combine_masks(uploaded_mask, sketched_mask))
    seed = int(seed)
    latents = torch.randn((1, 4, 64, 64), generator=torch.Generator().manual_seed(seed)).to("cuda")
    final_image = clipaway.generate(
        prompt=[""], scale=1, seed=seed,
        pil_image=[image_pil], alpha=[mask], strength=1, latents=latents
    )[0]
    return final_image

# Define example data
examples = [
    ["assets/gradio_examples/images/1.jpg", "assets/gradio_examples/masks/1.png", 42],
    ["assets/gradio_examples/images/2.jpg", "assets/gradio_examples/masks/2.png", 42],
    ["assets/gradio_examples/images/3.jpg", "assets/gradio_examples/masks/3.png", 464],
    ["assets/gradio_examples/images/4.jpg", "assets/gradio_examples/masks/4.png", 2024],
]

# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align:center'>CLIPAway: Harmonizing Focused Embeddings for Removing Objects via Diffusion Models</h1>")
    gr.Markdown("""
        <div style='display:flex; justify-content:center; align-items:center;'>
            <a href='https://arxiv.org/abs/2406.09368' style="margin:10px;">Paper</a> |
            <a href='https://yigitekin.github.io/CLIPAway/' style="margin:10px;">Project Website</a> |
            <a href='https://github.com/YigitEkin/CLIPAway' style="margin:10px;">GitHub</a>
        </div>
    """)
    gr.Markdown("""
            This application allows you to remove objects from images using the CLIPAway method with diffusion models.
            To use this tool:
            1. Upload an image.
            2. Either sketch a mask over the object you want to remove or upload a pre-defined mask if you have one.
            3. Set the seed for reproducibility (default is 42).
            4. Click 'Remove Object' to process the image.
            5. The result will be displayed on the right side.
            Note: The mask should be a binary image where the object to be removed is white and the background is black.
    """)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image and Sketch Mask", type="pil", tool="sketch")
            uploaded_mask = gr.Image(label="Upload Mask (Optional)", type="pil", optional=True)
            seed_input = gr.Number(value=42, label="Seed")
            process_button = gr.Button("Remove Object")
        with gr.Column():
            result_image = gr.Image(label="Result")

    process_button.click(
        fn=remove_obj,
        inputs=[image_input, uploaded_mask, seed_input],
        outputs=result_image
    )

    gr.Examples(
        examples=examples,
        inputs=[image_input, uploaded_mask, seed_input],
        outputs=result_image
    )

# Launch the interface
if args.share:
    demo.launch(share=True)
else:
    demo.launch()
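The same CLIPAway call that backs the "Remove Object" button can be driven without the UI. A minimal sketch, reusing the dilate_mask helper and the clipaway instance defined in app.py and assuming the bundled example assets are present:

import torch
from PIL import Image

# Run one object-removal pass on the first bundled example.
image = Image.open("assets/gradio_examples/images/1.jpg").convert("RGB")
mask = dilate_mask(Image.open("assets/gradio_examples/masks/1.png"))
latents = torch.randn((1, 4, 64, 64), generator=torch.Generator().manual_seed(42)).to("cuda")
result = clipaway.generate(
    prompt=[""], scale=1, seed=42,
    pil_image=[image], alpha=[mask], strength=1, latents=latents
)[0]
result.save("removed.png")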
    	
clip_l14_grit+mim_fultune_6xe.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5f3f2e24459e9764d9f4b4c053fb354dc9d508bd8f647b952402d6860bc9c3d
size 1216760175
    	
config/inference_config.yaml
ADDED
@@ -0,0 +1,16 @@
device: "cuda"
root_path: assets/gradio_examples
image_encoder_path: image_encoder
alpha_clip_ckpt_pth: clip_l14_grit+mim_fultune_6xe.pth
alpha_clip_id: ViT-L/14
ip_adapter_ckpt_path: ip-adapter_sd15.bin
sd_model_key: "runwayml/stable-diffusion-inpainting"
number_of_hidden_layers: 6
alpha_clip_embed_dim: 768
ip_adapter_embed_dim: 1024
mlp_projection_layer_ckpt_path: model.safetensors
save_path_prefix: test/results
seed: 42
scale: 1
strength: 1
display_focused_embeds: True
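app.py reads this file with OmegaConf and accesses each key as an attribute, for example:

from omegaconf import OmegaConf

config = OmegaConf.load("config/inference_config.yaml")
print(config.device)                # "cuda"
print(config.alpha_clip_id)         # "ViT-L/14"
print(config.ip_adapter_embed_dim)  # 1024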
    	
image_encoder/config.json
ADDED
@@ -0,0 +1,23 @@
{
  "_name_or_path": "./image_encoder",
  "architectures": [
    "CLIPVisionModelWithProjection"
  ],
  "attention_dropout": 0.0,
  "dropout": 0.0,
  "hidden_act": "gelu",
  "hidden_size": 1280,
  "image_size": 224,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 5120,
  "layer_norm_eps": 1e-05,
  "model_type": "clip_vision_model",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 32,
  "patch_size": 14,
  "projection_dim": 1024,
  "torch_dtype": "float16",
  "transformers_version": "4.28.0.dev0"
}
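This config describes a ViT-H/14-sized CLIP vision tower (hidden size 1280, 32 layers, patch size 14, 1024-d projection), which matches the ip_adapter_embed_dim of 1024 in the inference config and so appears to be the encoder consumed by the IP-Adapter branch. A short sketch of loading it from the local folder referenced by image_encoder_path:

from transformers import CLIPVisionModelWithProjection

image_encoder = CLIPVisionModelWithProjection.from_pretrained("image_encoder")
print(image_encoder.config.hidden_size, image_encoder.config.projection_dim)  # 1280 1024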
    	
image_encoder/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d3ec1e66737f77a4f3bc2df3c52eacefc69ce7825e2784183b1d4e9877d9193
size 2528481905
    	
ip-adapter_sd15.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:68e1df30d760f280e578c302f1e73b37ea08654eff16a31153588047affe0058
size 44642825
    	
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ade94c0505170a7698afe8ad4b4fb2307d06f67917b877cf1fd694a43cd6e335
size 22877152
    	
model/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .clip_away import CLIPAway

__all__ = [
    "CLIPAway"
]
    	
        model/attention_processor.py
    ADDED
    
    | 
         @@ -0,0 +1,189 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
"""
taken from https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/attention_processor.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat


class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # split encoder_hidden_states into text tokens and ip_hidden_states:
            # the last `num_tokens` entries are the image-prompt tokens
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        #!MASK HERE
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        #!MASK HERE
        self.attn_map = ip_attention_probs
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
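A minimal, self-contained sketch of the token layout IPAttnProcessor relies on: the image-prompt tokens are appended after the text tokens in `encoder_hidden_states`, and the processor splits them back apart at `end_pos` before attending to each stream separately. The batch size 2, 77 text tokens, and 768-dim embeddings below are illustrative assumptions, not values fixed by this file.

    import torch

    num_tokens = 4                              # image-prompt tokens per sample
    text_tokens = torch.randn(2, 77, 768)       # (batch, text_len, dim), e.g. CLIP text embeds
    image_tokens = torch.randn(2, num_tokens, 768)

    # how the caller is expected to build encoder_hidden_states
    encoder_hidden_states = torch.cat([text_tokens, image_tokens], dim=1)  # (2, 81, 768)

    # the split performed inside IPAttnProcessor.__call__
    end_pos = encoder_hidden_states.shape[1] - num_tokens
    text_part = encoder_hidden_states[:, :end_pos, :]   # (2, 77, 768) -> regular cross-attention
    ip_part = encoder_hidden_states[:, end_pos:, :]     # (2, 4, 768)  -> to_k_ip / to_v_ip branch

    assert torch.equal(text_part, text_tokens) and torch.equal(ip_part, image_tokens)

The two attention outputs are then blended as `hidden_states + scale * ip_hidden_states`, so `scale=0` recovers plain text-conditioned attention.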
    	
model/clip_away.py
ADDED
@@ -0,0 +1,280 @@
"""
modified from https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py
"""
import os
from typing import List
import torch
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection
import alpha_clip
from .utils import get_generator
from .attention_processor import AttnProcessor, IPAttnProcessor
from safetensors import safe_open
from safetensors.torch import load_model
import numpy as np

import torch.nn as nn


class ImageProjModel(torch.nn.Module):
    """Projection model of IP-Adapter."""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class CLIPAway:
    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, alpha_clip_path, config, alpha_clip_id="ViT-L/14", device="cuda", num_tokens=4):
        super().__init__()
        self.device = device
        self.ipadapter_image_encoder_path = image_encoder_path
        self.ipadapter_ckpt = ip_ckpt
        self.num_tokens = num_tokens

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()
        alpha_clip_model, alpha_clip_preprocess = alpha_clip.load(alpha_clip_id, alpha_vision_ckpt_pth=alpha_clip_path, device=device)

        # load image encoder
        self.image_encoder = alpha_clip_model.visual.to(self.device, dtype=torch.float32)

        self.clip_proj = CLIPVisionModelWithProjection.from_pretrained(self.ipadapter_image_encoder_path).to(
            self.device, dtype=torch.float32
        )
        self.alpha_clip_image_processor = alpha_clip_preprocess

        # mask preprocessing transform for Alpha-CLIP: 336x336 for ViT-L/14@336px, 224x224 otherwise
        if "@336" in alpha_clip_id:
            self.mask_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((336, 336)),
                transforms.Normalize(0.5, 0.26)
            ])
        else:
            self.mask_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.Normalize(0.5, 0.26)
            ])
        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()
        self.mlp_projection_layer = self.generate_projection_layer(config)

        print(config.mlp_projection_layer_ckpt_path, type(config.mlp_projection_layer_ckpt_path))
        if config.mlp_projection_layer_ckpt_path is not None:
            self.load_projection_layer(config.mlp_projection_layer_ckpt_path)

    def load_projection_layer(self, path):
        load_model(self.mlp_projection_layer, path)
        print("Projection layer loaded from", path)

    def generate_projection_layer(self, config):
        projection_layer = nn.ModuleList()

        for i in range(config.number_of_hidden_layers):
            if i < config.number_of_hidden_layers // 2:
                projection_layer.append(nn.Linear(config.alpha_clip_embed_dim, config.alpha_clip_embed_dim))
                projection_layer.append(nn.LayerNorm(config.alpha_clip_embed_dim))
            elif i == config.number_of_hidden_layers // 2:
                projection_layer.append(nn.Linear(config.alpha_clip_embed_dim, config.ip_adapter_embed_dim))
                projection_layer.append(nn.LayerNorm(config.ip_adapter_embed_dim))
            else:
                projection_layer.append(nn.Linear(config.ip_adapter_embed_dim, config.ip_adapter_embed_dim))
                projection_layer.append(nn.LayerNorm(config.ip_adapter_embed_dim))
            projection_layer.append(nn.GELU())

        projection_layer.append(nn.Linear(config.ip_adapter_embed_dim, config.ip_adapter_embed_dim))

        return nn.Sequential(*projection_layer).to(self.device).to(torch.float32)

    def init_proj(self):
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.clip_proj.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.float32)
        return image_proj_model

    def set_ip_adapter(self):
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor().to(self.device)
            else:
                attn_procs[name] = IPAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=self.num_tokens,
                ).to(self.device, dtype=torch.float32)
        unet.set_attn_processor(attn_procs)

    def get_alpha_clip_embeds(self, pil_image, alpha):
        clip_image = [self.alpha_clip_image_processor(image) for image in pil_image]
        clip_image = torch.stack(clip_image).to(self.device, dtype=torch.float32)
        masks = [self.mask_transform(mask) for mask in alpha]
        masks = torch.stack(masks).to(self.device, dtype=torch.float32)

        return self.image_encoder(clip_image, masks)

    def load_ip_adapter(self):
        if os.path.splitext(self.ipadapter_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ipadapter_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            state_dict = torch.load(self.ipadapter_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    def get_complement_of_mask(self, mask):
        return Image.fromarray((255 - np.array(mask[0])).astype(np.uint8))

    def clipaway_projection_block(self, bg_embeds, fg_embeds):
        # subtract from the background embedding its component along the foreground embedding
        projected_vector_magnitude = bg_embeds[0].dot(fg_embeds[0]) / fg_embeds[0].norm()
        projected_vector = projected_vector_magnitude * fg_embeds / fg_embeds.norm()
        return bg_embeds - projected_vector

    def get_focused_embeddings(self, pil_image, alpha, use_projection_block=False):
        # get focused alpha clip embeds
        clip_image_embeds_fg = self.get_alpha_clip_embeds(pil_image, alpha)
        clip_image_embeds_bg = self.get_alpha_clip_embeds(pil_image, [self.get_complement_of_mask(alpha)])

        # mlp projection
        projected_alpha_clip_embeds_fg = self.mlp_projection_layer(clip_image_embeds_fg)
        projected_alpha_clip_embeds_bg = self.mlp_projection_layer(clip_image_embeds_bg)

        # ip adapter logic
        image_prompt_embeds_fg = self.image_proj_model(projected_alpha_clip_embeds_fg)
        image_prompt_embeds_bg = self.image_proj_model(projected_alpha_clip_embeds_bg)
        uncond_image_prompt_embeds = self.image_proj_model(self.mlp_projection_layer(torch.zeros_like(clip_image_embeds_fg)))

        if use_projection_block:
            # clipaway projection block
            projected_alpha_clip_embeds = self.clipaway_projection_block(projected_alpha_clip_embeds_bg, projected_alpha_clip_embeds_fg)
            image_prompt_embeds = self.image_proj_model(projected_alpha_clip_embeds)
            return image_prompt_embeds, image_prompt_embeds_fg, image_prompt_embeds_bg, uncond_image_prompt_embeds

        return image_prompt_embeds_fg, image_prompt_embeds_bg, uncond_image_prompt_embeds

    def get_ipadapter_embeds(self, pil_image=None, alpha=None):
        # get focused alpha clip embeds
        clip_image_embeds_fg = self.get_alpha_clip_embeds(pil_image, alpha)
        clip_image_embeds_bg = self.get_alpha_clip_embeds(pil_image, [self.get_complement_of_mask(alpha)])

        # mlp projection
        projected_alpha_clip_embeds_fg = self.mlp_projection_layer(clip_image_embeds_fg)
        projected_alpha_clip_embeds_bg = self.mlp_projection_layer(clip_image_embeds_bg)

        # clipaway projection block
        projected_alpha_clip_embeds = self.clipaway_projection_block(projected_alpha_clip_embeds_bg, projected_alpha_clip_embeds_fg)

        # ip adapter logic
        image_prompt_embeds = self.image_proj_model(projected_alpha_clip_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(self.mlp_projection_layer(torch.zeros_like(clip_image_embeds_fg)))

        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    @torch.inference_mode()
    def generate(
        self,
        pil_image=None,
        alpha=None,
        prompt=None,
        negative_prompt=None,
        image_prompt_embeds=None,
        uncond_image_prompt_embeds=None,
        scale=1.0,
        num_samples=1,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=50,
        **kwargs,
    ):
        self.set_scale(scale)
        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        if image_prompt_embeds is None or uncond_image_prompt_embeds is None:
            image_prompt_embeds, uncond_image_prompt_embeds = self.get_ipadapter_embeds(pil_image=pil_image, alpha=alpha)
        else:
            image_prompt_embeds = image_prompt_embeds.to(self.device)
            uncond_image_prompt_embeds = uncond_image_prompt_embeds.to(self.device)

        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.view(bs_embed, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            image=pil_image,
            mask_image=alpha,
            **kwargs,
        ).images

        return images
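The heart of clip_away.py is `clipaway_projection_block`: it removes from the background Alpha-CLIP embedding its component along the foreground embedding, so the result carries no foreground direction before it is handed to the IP-Adapter projection. A minimal, self-contained sketch that reproduces the same arithmetic on random toy vectors (the 1024-dimensional size is an illustrative assumption) and checks the orthogonality:

    import torch

    torch.manual_seed(0)
    fg = torch.randn(1, 1024)   # stands in for the projected foreground (object) embedding
    bg = torch.randn(1, 1024)   # stands in for the projected background embedding

    # same arithmetic as clipaway_projection_block
    magnitude = bg[0].dot(fg[0]) / fg[0].norm()
    projection = magnitude * fg / fg.norm()
    result = bg - projection

    # the remaining foreground component is numerically zero
    print(result[0].dot(fg[0]).abs().item())  # ~0 up to floating-point error

This is why `get_ipadapter_embeds` feeds the projection block with the background embedding first and the foreground embedding second: the output is "background minus whatever of the foreground it still contains".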
    	
model/resampler.py
ADDED
@@ -0,0 +1,158 @@
| 1 | 
         
            +
            """
         
     | 
| 2 | 
         
            +
            taken from https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
         
     | 
| 3 | 
         
            +
            """
         
     | 
| 4 | 
         
            +
            import math
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            import torch
         
     | 
| 7 | 
         
            +
            import torch.nn as nn
         
     | 
| 8 | 
         
            +
            from einops import rearrange
         
     | 
| 9 | 
         
            +
            from einops.layers.torch import Rearrange
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # make contiguous; shape stays (bs, n_heads, length, dim_per_head)
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
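In each PerceiverAttention block, a small set of learned latent queries attends to the image tokens: queries come from the latents, while keys and values are computed from the concatenation of image features and latents, so the output keeps the latents' sequence length. As a reading aid only (not part of the committed resampler.py), here is a minimal sketch of running the block on dummy inputs; the batch size and shapes are illustrative, chosen to match the Resampler defaults that follow.

# Illustrative sketch only -- not part of this commit's resampler.py.
# Shapes are assumptions picked to match the module defaults below.
import torch

attn = PerceiverAttention(dim=1024, dim_head=64, heads=16)
image_tokens = torch.randn(2, 257, 1024)  # (batch, n1 image tokens, dim)
latents = torch.randn(2, 8, 1024)         # (batch, n2 latent queries, dim)
out = attn(image_tokens, latents)         # -> (2, 8, 1024), same shape as the latents

The Resampler class below stacks depth such blocks, each followed by a FeedForward layer, with residual connections on both.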
class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


def masked_mean(t, *, dim, mask=None):
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, "b n -> b n 1")
    masked_t = t.masked_fill(~mask, 0.0)

    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
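End-to-end, the Resampler projects a sequence of image-encoder embeddings from embedding_dim to dim, refines num_queries learned latents through the stacked PerceiverAttention and FeedForward blocks with residual connections, then maps them to output_dim and layer-normalizes the result. A minimal usage sketch follows, assuming the defaults in this file; the input tensor and batch size are illustrative, and real checkpoints may be configured with different dimensions.

# Illustrative sketch only -- not part of the committed file.
# Dimensions follow the defaults above; actual checkpoints may differ.
import torch

resampler = Resampler(
    dim=1024,
    depth=8,
    dim_head=64,
    heads=16,
    num_queries=8,
    embedding_dim=768,
    output_dim=1024,
    ff_mult=4,
)
image_embeds = torch.randn(2, 257, 768)  # (batch, encoder tokens incl. CLS, embedding_dim)
tokens = resampler(image_embeds)         # -> (2, 8, 1024): num_queries tokens of output_dim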