diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..24aa54d81a0af0eeffa744a6c81e8a79f5cb76a3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +*.pyc +__pycache__ +test.py +flagged +output +gradio_cached* +dist* +*egg-info +build* diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d56ccfbb7496de8592c9745d0d9a5e390af75fd6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,437 @@ +Attribution-NonCommercial-ShareAlike 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible.cp + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..5dcbc139b80a520e73d210bb3236cb0a25a31129 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include *.py LICENSE README.md +recursive-include audioldm2 *.txt *.py *.gz *.npy *.json \ No newline at end of file diff --git a/README.md b/README.md index d8cd2b7e733713d278e699014b616f91add1cfaa..4d55fdfc128beb20be6f2512363a1c8663bff795 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,23 @@ --- -title: Audioldm2 Text2audio Text2music -emoji: 👁 -colorFrom: gray -colorTo: green +title: AudioLDM2 Text2Audio Text2Music Generation +emoji: 🔊 +colorFrom: indigo +colorTo: red sdk: gradio -sdk_version: 3.39.0 +sdk_version: 3.27.0 app_file: app.py pinned: false -license: cc-by-nc-nd-4.0 +license: bigscience-openrail-m +duplicated_from: haoheliu/audioldm2-text2audio-text2music + --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + +## Reference +Part of the code from this repo is borrowed from the following repos. We would like to thank the authors of them for their contribution. + +> https://github.com/LAION-AI/CLAP +> https://github.com/CompVis/stable-diffusion +> https://github.com/v-iashin/SpecVQGAN +> https://github.com/toshas/torch-fidelity \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d2bba800105bda2f7fd171a68f28c31b00a87bd9 --- /dev/null +++ b/app.py @@ -0,0 +1,361 @@ +from huggingface_hub import hf_hub_download +import torch +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +import gradio as gr +from audioldm2 import text_to_audio, build_model +from share_btn import community_icon_html, loading_icon_html, share_js + +model_id = "haoheliu/audioldm2-full" +hf_hub_download(repo_id="haoheliu/audioldm2-full", filename="audioldm2-full.pth") + +audioldm = None +current_model_name = None + +def text2audio( + text, + guidance_scale, + random_seed, + n_candidates, + model_name="audioldm2-full", +): + global audioldm, current_model_name + torch.set_float32_matmul_precision("high") + + if audioldm is None or model_name != current_model_name: + audioldm = build_model(model_name=model_name) + current_model_name = model_name + audioldm = torch.compile(audioldm) + + # print(text, length, guidance_scale) + waveform = text_to_audio( + latent_diffusion=audioldm, + text=text, + seed=random_seed, + duration=10, + guidance_scale=guidance_scale, + n_candidate_gen_per_text=int(n_candidates), + ) # [bs, 1, samples] + waveform = [ + gr.make_waveform((16000, wave[0]), bg_image="bg.png") for wave in waveform + ] + # waveform = [(16000, np.random.randn(16000)), (16000, np.random.randn(16000))] + if len(waveform) == 1: + waveform = waveform[0] + return waveform + +css = """ + a { + color: inherit; + text-decoration: underline; + } + .gradio-container { + font-family: 'IBM Plex Sans', sans-serif; + } + .gr-button { + color: white; + border-color: #000000; + background: #000000; + } + input[type='range'] { + accent-color: #000000; + } + .dark input[type='range'] { + accent-color: #dfdfdf; + } + .container { + max-width: 730px; + margin: auto; + padding-top: 1.5rem; + } + #gallery { + min-height: 22rem; + margin-bottom: 15px; + margin-left: auto; + margin-right: auto; + border-bottom-right-radius: .5rem !important; + border-bottom-left-radius: .5rem !important; + } + #gallery>div>.h-full { + min-height: 20rem; + } + .details:hover { + text-decoration: underline; + } + .gr-button { + white-space: nowrap; + } + .gr-button:focus { + border-color: rgb(147 197 253 / var(--tw-border-opacity)); + outline: none; + box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); + --tw-border-opacity: 1; + --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); + --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color); + --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity)); + --tw-ring-opacity: .5; + } + #advanced-btn { + font-size: .7rem !important; + line-height: 19px; + margin-top: 12px; + margin-bottom: 12px; + padding: 2px 8px; + border-radius: 14px !important; + } + #advanced-options { + margin-bottom: 20px; + } + .footer { + margin-bottom: 45px; + margin-top: 35px; + text-align: center; + border-bottom: 1px solid #e5e5e5; + } + .footer>p { + font-size: .8rem; + display: inline-block; + padding: 0 10px; + transform: translateY(10px); + background: white; + } + .dark .footer { + border-color: #303030; + } + .dark .footer>p { + background: #0b0f19; + } + .acknowledgments h4{ + margin: 1.25em 0 .25em 0; + font-weight: bold; + font-size: 115%; + } + #container-advanced-btns{ + display: flex; + flex-wrap: wrap; + justify-content: space-between; + align-items: center; + } + .animate-spin { + animation: spin 1s linear infinite; + } + @keyframes spin { + from { + transform: rotate(0deg); + } + to { + transform: rotate(360deg); + } + } + #share-btn-container { + display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem; + margin-top: 10px; + margin-left: auto; + } + #share-btn { + all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0; + } + #share-btn * { + all: unset; + } + #share-btn-container div:nth-child(-n+2){ + width: auto !important; + min-height: 0px !important; + } + #share-btn-container .wrap { + display: none !important; + } + .gr-form{ + flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0; + } + #prompt-container{ + gap: 0; + } + #generated_id{ + min-height: 700px + } + #setting_id{ + margin-bottom: 12px; + text-align: center; + font-weight: 900; + } +""" +iface = gr.Blocks(css=css) + +with iface: + gr.HTML( + """ +
+
+

+ AudioLDM 2: A General Framework for Audio, Music, and Speech Generation +

+
+

+ [Paper] [Project page] +

+
+ """ + ) + gr.HTML( + """ +

+ AudioLDM 2: A General Framework for Audio, Music, and Speech Generation +

+

For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. +
+ + Duplicate Space +

+ """ + ) + with gr.Group(): + with gr.Box(): + ############# Input + textbox = gr.Textbox( + value="A forest of wind chimes singing a soothing melody in the breeze.", + max_lines=1, + label="Input your text here. Your text is important for the audio quality. Please ensure it is descriptive by using more adjectives.", + elem_id="prompt-in", + ) + + with gr.Accordion("Click to modify detailed configurations", open=False): + seed = gr.Number( + value=45, + label="Change this value (any integer number) will lead to a different generation result.", + ) + # duration = gr.Slider( + # 10, 10, value=10, step=2.5, label="Duration (seconds)" + # ) + guidance_scale = gr.Slider( + 0, + 6, + value=3.5, + step=0.5, + label="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)", + ) + n_candidates = gr.Slider( + 1, + 3, + value=3, + step=1, + label="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation", + ) + # model_name = gr.Dropdown( + # ["audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full","audioldm-s-full-v2", "audioldm-s-full", "audioldm-l-full"], value="audioldm-m-full", label="Choose the model to use. audioldm-m-text-ft and audioldm-s-text-ft are recommanded. -s- means small, -m- means medium and -l- means large", + # ) + ############# Output + # outputs=gr.Audio(label="Output", type="numpy") + outputs = gr.Video(label="Output", elem_id="output-video") + + # with gr.Group(elem_id="container-advanced-btns"): + # # advanced_button = gr.Button("Advanced options", elem_id="advanced-btn") + # with gr.Group(elem_id="share-btn-container"): + # community_icon = gr.HTML(community_icon_html, visible=False) + # loading_icon = gr.HTML(loading_icon_html, visible=False) + # share_button = gr.Button("Share to community", elem_id="share-btn", visible=False) + # outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")] + btn = gr.Button("Submit").style(full_width=True) + + with gr.Group(elem_id="share-btn-container", visible=False): + community_icon = gr.HTML(community_icon_html) + loading_icon = gr.HTML(loading_icon_html) + share_button = gr.Button("Share to community", elem_id="share-btn") + + # btn.click(text2audio, inputs=[ + # textbox, duration, guidance_scale, seed, n_candidates, model_name], outputs=[outputs]) + btn.click( + text2audio, + inputs=[textbox, guidance_scale, seed, n_candidates], + outputs=[outputs], + ) + + share_button.click(None, [], [], _js=share_js) + gr.HTML( + """ +

+ """ + ) + gr.Examples( + [ + [ + "An excited crowd cheering at a sports game.", + 3.5, + 45, + 3, + "audioldm2-full", + ], + [ + "A cat is meowing for attention.", + 3.5, + 45, + 3, + "audioldm2-full", + ], + [ + "Birds singing sweetly in a blooming garden.", + 3.5, + 45, + 3, + "audioldm2-full", + ], + [ + "A modern synthesizer creating futuristic soundscapes.", + 3.5, + 45, + 3, + "audioldm2-full", + ], + [ + "The vibrant beat of Brazilian samba drums.", + 3.5, + 45, + 3, + "audioldm2-full", + ], + ], + fn=text2audio, + # inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name], + inputs=[textbox, guidance_scale, seed, n_candidates], + outputs=[outputs], + cache_examples=True, + ) + gr.HTML( + """ +
+

Essential Tricks for Enhancing the Quality of Your Generated Audio

+

1. Try to use more adjectives to describe your sound. For example: "A man is speaking clearly and slowly in a large room" is better than "A man is speaking". This can make sure AudioLDM understands what you want.

+

2. Try to use different random seeds, which can affect the generation quality significantly sometimes.

+

3. It's better to use general terms like 'man' or 'woman' instead of specific names for individuals or abstract objects that humans may not be familiar with, such as 'mummy'.

+
+ """ + ) + + with gr.Accordion("Additional information", open=False): + gr.HTML( + """ +
+

We build the model with data from AudioSet, Freesound and BBC Sound Effect library. We share this demo based on the UK copyright exception of data for academic research.

+
+ """ + ) +#

This demo is strictly for research demo purpose only. For commercial use please contact us.

+ +iface.queue(concurrency_count=3) +# iface.launch(debug=True) +iface.launch(debug=True, share=True) diff --git a/audioldm2/__init__.py b/audioldm2/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..91befda907125b4772601b1df2c9a8a52b733735 --- /dev/null +++ b/audioldm2/__init__.py @@ -0,0 +1,2 @@ +from .utils import seed_everything, save_wave, get_time, get_duration, read_list +from .pipeline import * diff --git a/audioldm2/audiomae_gen/__init__.py b/audioldm2/audiomae_gen/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..7202889ac7aeb7e5b344da994206715cbbb3891e --- /dev/null +++ b/audioldm2/audiomae_gen/__init__.py @@ -0,0 +1 @@ +from .sequence_input import Sequence2AudioMAE diff --git a/audioldm2/audiomae_gen/sequence_input.py b/audioldm2/audiomae_gen/sequence_input.py new file mode 100755 index 0000000000000000000000000000000000000000..4d961a0dd7157689fab6291bb3c40d9bd656b5f1 --- /dev/null +++ b/audioldm2/audiomae_gen/sequence_input.py @@ -0,0 +1,429 @@ +import torch +import torch.nn as nn +from audioldm2.latent_diffusion.util import ( + instantiate_from_config, +) + +# from latent_diffusion.modules.encoders.modules import CLAPAudioEmbeddingClassifierFreev2 +from transformers import GPT2Config, GPT2Model +import torch.optim.lr_scheduler as lr_scheduler + +class Sequence2AudioMAE(nn.Module): + def __init__( + self, + base_learning_rate, + sequence_gen_length, + sequence_input_key, + sequence_input_embed_dim, + cond_stage_config, + optimizer_type="AdamW", + use_warmup=True, + use_ar_gen_loss=False, + use_audiomae_linear=False, + target_tokens_mask_ratio=0.0, + random_mask_ratio=False, + **kwargs + ): + super().__init__() + assert use_audiomae_linear == False + self.random_mask_ratio = random_mask_ratio + self.learning_rate = base_learning_rate + self.cond_stage_config = cond_stage_config + self.use_audiomae_linear = use_audiomae_linear + self.optimizer_type = optimizer_type + self.use_warmup = use_warmup + self.use_ar_gen_loss = use_ar_gen_loss + # Even though the LDM can be conditioned on mutliple pooling rate + # Our model always predict the higest pooling rate + + # self.time_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["time_pooling_factors"]) + # self.freq_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["freq_pooling_factors"]) + # self.mae_token_num = int(512/(self.time_pool*self.freq_pool)) + + self.mae_token_num = sequence_gen_length + self.sequence_input_key = sequence_input_key + self.sequence_input_embed_dim = sequence_input_embed_dim + self.target_tokens_mask_ratio = target_tokens_mask_ratio + + self.start_of_sequence_tokens = nn.Embedding(32, 768) + self.end_of_sequence_tokens = nn.Embedding(32, 768) + + self.input_sequence_embed_linear = nn.ModuleList([]) + self.initial_learning_rate = None + + for dim in self.sequence_input_embed_dim: + self.input_sequence_embed_linear.append(nn.Linear(dim, 768)) + + self.cond_stage_models = nn.ModuleList([]) + self.instantiate_cond_stage(cond_stage_config) + self.initialize_param_check_toolkit() + + # configuration = GPT2Config(n_layer=1) # TODO + # self.model=GPT2Model(configuration) + ################### + # self.model=nn.Linear(768,768, bias=False) # TODO change the model + # with torch.no_grad(): + # self.model.weight.copy_(torch.eye(768)) + ################### + self.model = GPT2Model(GPT2Config.from_pretrained("gpt2")) + ################### + # self.model = nn.LSTM(input_size=768, hidden_size=768, num_layers=1,bias=False) # TODO + + # self.loss_fn = nn.MSELoss() + self.loss_fn = nn.L1Loss() + + self.logger_save_dir = None + self.logger_exp_name = None + self.logger_exp_group_name = None + self.logger_version = None + + def set_log_dir(self, save_dir, exp_group_name, exp_name): + self.logger_save_dir = save_dir + self.logger_exp_group_name = exp_group_name + self.logger_exp_name = exp_name + + def cfg_uncond(self, batch_size): + unconditional_conditioning = {} + for key in self.cond_stage_model_metadata: + model_idx = self.cond_stage_model_metadata[key]["model_idx"] + unconditional_conditioning[key] = self.cond_stage_models[ + model_idx + ].get_unconditional_condition(batch_size) + assert ( + "crossattn_audiomae_pooled" in unconditional_conditioning.keys() + ), "The module is not initialized with AudioMAE" + unconditional_conditioning[ + "crossattn_clap_to_audiomae_feature" + ] = unconditional_conditioning["crossattn_audiomae_pooled"] + return unconditional_conditioning + + def configure_optimizers(self): + lr = float(self.learning_rate) + # params = list(self.model.parameters()) + list(self.input_sequence_embed_linear.parameters()) + params = list(self.parameters()) + + # opt = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.98), eps=1e-9) + opt = eval(self.optimizer_type)(params, lr=lr) + scheduler = lr_scheduler.StepLR(opt, step_size=10, gamma=0.8) + return [opt], [scheduler] + + def add_sos_eos_tokens(self, _id, sequence, attn_mask): + batchsize = sequence.size(0) + + new_attn_mask_step = torch.ones((batchsize, 1)).to(sequence.device) + key_id = torch.tensor([_id]).to(sequence.device) + + # Add two more steps to attn mask + new_attn_mask = torch.cat( + [new_attn_mask_step, attn_mask, new_attn_mask_step], dim=1 + ) + + # Add two more tokens in the sequence + sos_token = self.start_of_sequence_tokens(key_id).expand(batchsize, 1, -1) + eos_token = self.end_of_sequence_tokens(key_id).expand(batchsize, 1, -1) + new_sequence = torch.cat([sos_token, sequence, eos_token], dim=1) + return new_sequence, new_attn_mask + + def truncate_sequence_and_mask(self, sequence, mask, max_len=512): + if sequence.size(1) > max_len: + print( + "The input sequence length to GPT-2 model is too long:", + sequence.size(1), + ) + return sequence[:, :max_len], mask[:, :max_len] + else: + return sequence, mask + + def get_input_sequence_and_mask(self, cond_dict): + input_embeds = None + input_embeds_attn_mask = None + for _id, sequence_key in enumerate(self.sequence_input_key): + assert sequence_key in cond_dict.keys(), ( + "Invalid sequence key %s" % sequence_key + ) + cond_embed = cond_dict[sequence_key] + if isinstance(cond_embed, list): + assert ( + len(cond_embed) == 2 + ), "The crossattn returned list should have length 2, including embed and attn_mask" + item_input_embeds, item_attn_mask = cond_embed + + item_input_embeds = self.input_sequence_embed_linear[_id]( + item_input_embeds + ) + + item_input_embeds, item_attn_mask = self.add_sos_eos_tokens( + _id, item_input_embeds, item_attn_mask + ) + + if input_embeds is None and input_embeds_attn_mask is None: + input_embeds, input_embeds_attn_mask = ( + item_input_embeds, + item_attn_mask, + ) + else: + input_embeds = torch.cat( + [input_embeds, item_input_embeds], dim=1 + ) # The 1-st dimension is time steps + input_embeds_attn_mask = torch.cat( + [input_embeds_attn_mask, item_attn_mask], dim=1 + ) # The 1-st dimension is time steps + else: + assert isinstance(cond_embed, torch.Tensor) + cond_embed = self.input_sequence_embed_linear[_id](cond_embed) + attn_mask = torch.ones((cond_embed.size(0), cond_embed.size(1))).to( + cond_embed.device + ) + + item_input_embeds, item_attn_mask = self.add_sos_eos_tokens( + _id, cond_embed, attn_mask + ) + + if input_embeds is None and input_embeds_attn_mask is None: + input_embeds, input_embeds_attn_mask = ( + item_input_embeds, + item_attn_mask, + ) + else: + input_embeds, input_embeds_attn_mask = torch.cat( + [input_embeds, item_input_embeds], dim=1 + ), torch.cat([input_embeds_attn_mask, item_attn_mask], dim=1) + + assert input_embeds is not None and input_embeds_attn_mask is not None + + input_embeds, input_embeds_attn_mask = self.truncate_sequence_and_mask( + input_embeds, input_embeds_attn_mask, int(1024 - self.mae_token_num) + ) + cond_sequence_end_time_idx = input_embeds.size( + 1 + ) # The index that we start to collect the output embeds + + return input_embeds, input_embeds_attn_mask, cond_sequence_end_time_idx + + def warmup_step(self): + if self.initial_learning_rate is None: + self.initial_learning_rate = float(self.learning_rate) + + # Only the first parameter group + if self.global_step <= 1000: + if self.global_step == 0: + print( + "Warming up learning rate start with %s" + % self.initial_learning_rate + ) + self.trainer.optimizers[0].param_groups[0]["lr"] = ( + self.global_step / 1000 + ) * self.initial_learning_rate + else: + # TODO set learning rate here + self.trainer.optimizers[0].param_groups[0][ + "lr" + ] = self.initial_learning_rate + + def mask_target_sequence(self, target_embeds, target_embeds_attn_mask): + time_seq_mask = None + if self.target_tokens_mask_ratio > 1e-4: + batchsize, time_seq_len, embed_dim = target_embeds.size() + _, time_seq_len = target_embeds_attn_mask.size() + # Generate random mask + if self.random_mask_ratio: + mask_ratio = torch.rand(1).item() * self.target_tokens_mask_ratio + else: + mask_ratio = self.target_tokens_mask_ratio + + time_seq_mask = (torch.rand((batchsize, time_seq_len)) > mask_ratio).to( + target_embeds.device + ) + # Mask the target embedding + target_embeds = target_embeds * time_seq_mask.unsqueeze(-1) + target_embeds_attn_mask = target_embeds_attn_mask * time_seq_mask + return target_embeds, target_embeds_attn_mask, time_seq_mask + + def generate_partial(self, batch, cond_dict=None, no_grad=False): + if cond_dict is None: + cond_dict = self.get_input(batch) + + print("Generate partially prompted audio with in-context learning") + # self.model.train() + # assert self.model.training==True + + target_embeds, target_embeds_attn_mask = ( + cond_dict["crossattn_audiomae_pooled"][0], + cond_dict["crossattn_audiomae_pooled"][1], + ) + + target_time_steps = target_embeds.size(1) + + ( + input_embeds, + input_embeds_attn_mask, + cond_sequence_end_time_idx, + ) = self.get_input_sequence_and_mask(cond_dict) + + model_input = torch.cat( + [input_embeds, target_embeds[:, : target_time_steps // 4, :]], dim=1 + ) + model_input_mask = torch.cat( + [ + input_embeds_attn_mask, + target_embeds_attn_mask[:, : target_time_steps // 4], + ], + dim=1, + ) + + steps = self.mae_token_num + + for _ in range(3 * steps // 4): + output = self.model( + inputs_embeds=model_input, attention_mask=model_input_mask + )["last_hidden_state"] + # Update the model input + model_input = torch.cat([model_input, output[:, -1:, :]], dim=1) + # Update the attention mask + attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to( + model_input.device + ) + model_input_mask = torch.cat( + [model_input_mask, attention_mask_new_step], dim=1 + ) + + output = model_input[:, cond_sequence_end_time_idx:] + + return output, cond_dict + + def generate(self, batch, cond_dict=None, no_grad=False): + if cond_dict is None: + cond_dict = self.get_input(batch) + + # self.model.train() + # print("!!!!!!!!!!!!!train") + + ( + input_embeds, + input_embeds_attn_mask, + cond_sequence_end_time_idx, + ) = self.get_input_sequence_and_mask(cond_dict) + model_input = input_embeds + model_input_mask = input_embeds_attn_mask + + steps = self.mae_token_num + + for _ in range(steps): + output = self.model( + inputs_embeds=model_input, attention_mask=model_input_mask + )["last_hidden_state"] + # Update the model input + model_input = torch.cat([model_input, output[:, -1:, :]], dim=1) + # Update the attention mask + attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to( + model_input.device + ) + model_input_mask = torch.cat( + [model_input_mask, attention_mask_new_step], dim=1 + ) + + return model_input[:, cond_sequence_end_time_idx:], cond_dict + + def get_input_item(self, batch, k): + fname, text, waveform, stft, fbank = ( + batch["fname"], + batch["text"], + batch["waveform"], + batch["stft"], + batch["log_mel_spec"], + ) + ret = {} + + ret["fbank"] = ( + fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float() + ) + ret["stft"] = stft.to(memory_format=torch.contiguous_format).float() + # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float() + ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float() + ret["text"] = list(text) + ret["fname"] = fname + + for key in batch.keys(): + if key not in ret.keys(): + ret[key] = batch[key] + + return ret[k] + + def get_input(self, batch): + cond_dict = {} + if len(self.cond_stage_model_metadata.keys()) > 0: + unconditional_cfg = False + + for cond_model_key in self.cond_stage_model_metadata.keys(): + cond_stage_key = self.cond_stage_model_metadata[cond_model_key][ + "cond_stage_key" + ] + + # if(not self.training): + # if(isinstance(self.cond_stage_models[self.cond_stage_model_metadata[cond_model_key]["model_idx"]], CLAPAudioEmbeddingClassifierFreev2)): + # assert cond_stage_key == "text" # CLAP model should use text for evaluation + + # The original data for conditioning + xc = self.get_input_item(batch, cond_stage_key) + if type(xc) == torch.Tensor: + xc = xc.to(self.device) + + c = self.get_learned_conditioning( + xc, key=cond_model_key, unconditional_cfg=unconditional_cfg + ) + cond_dict[cond_model_key] = c + + return cond_dict + + def instantiate_cond_stage(self, config): + self.cond_stage_model_metadata = {} + + for i, cond_model_key in enumerate(config.keys()): + model = instantiate_from_config(config[cond_model_key]) + self.cond_stage_models.append(model) + self.cond_stage_model_metadata[cond_model_key] = { + "model_idx": i, + "cond_stage_key": config[cond_model_key]["cond_stage_key"], + "conditioning_key": config[cond_model_key]["conditioning_key"], + } + + def get_learned_conditioning(self, c, key, unconditional_cfg): + assert key in self.cond_stage_model_metadata.keys() + + # Classifier-free guidance + if not unconditional_cfg: + c = self.cond_stage_models[ + self.cond_stage_model_metadata[key]["model_idx"] + ](c) + else: + if isinstance(c, torch.Tensor): + batchsize = c.size(0) + elif isinstance(c, list): + batchsize = len(c) + else: + raise NotImplementedError() + c = self.cond_stage_models[ + self.cond_stage_model_metadata[key]["model_idx"] + ].get_unconditional_condition(batchsize) + + return c + + def initialize_param_check_toolkit(self): + self.tracked_steps = 0 + self.param_dict = {} + + def statistic_require_grad_tensor_number(self, module, name=None): + requires_grad_num = 0 + total_num = 0 + require_grad_tensor = None + for p in module.parameters(): + if p.requires_grad: + requires_grad_num += 1 + if require_grad_tensor is None: + require_grad_tensor = p + total_num += 1 + print( + "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)" + % (name, requires_grad_num, total_num, requires_grad_num / total_num) + ) + return require_grad_tensor diff --git a/audioldm2/audiomae_gen/utils.py b/audioldm2/audiomae_gen/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..841d35adf338647bdf8bd1c31e9f33dee1252b6e --- /dev/null +++ b/audioldm2/audiomae_gen/utils.py @@ -0,0 +1,27 @@ +import torch.nn as nn + + +class Prenet(nn.Module): + def __init__(self, in_dim, sizes=[256, 128], dropout_rate=0.5): + super(Prenet, self).__init__() + in_sizes = [in_dim] + sizes[:-1] + self.layers = nn.ModuleList( + [ + nn.Linear(in_size, out_size) + for (in_size, out_size) in zip(in_sizes, sizes) + ] + ) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, inputs): + for linear in self.layers: + inputs = self.dropout(self.relu(linear(inputs))) + return inputs + + +if __name__ == "__main__": + model = Prenet(in_dim=128, sizes=[256, 256, 128]) + import ipdb + + ipdb.set_trace() diff --git a/audioldm2/clap/__init__.py b/audioldm2/clap/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm2/clap/open_clip/__init__.py b/audioldm2/clap/open_clip/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e9f728f2f273be5d5fdbec6c6cc41d737176a8c0 --- /dev/null +++ b/audioldm2/clap/open_clip/__init__.py @@ -0,0 +1,25 @@ +from .factory import ( + list_models, + create_model, + create_model_and_transforms, + add_model_config, +) +from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics +from .model import ( + CLAP, + CLAPTextCfg, + CLAPVisionCfg, + CLAPAudioCfp, + convert_weights_to_fp16, + trace_model, +) +from .openai import load_openai_model, list_openai_models +from .pretrained import ( + list_pretrained, + list_pretrained_tag_models, + list_pretrained_model_tags, + get_pretrained_url, + download_pretrained, +) +from .tokenizer import SimpleTokenizer, tokenize +from .transform import image_transform diff --git a/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz b/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100755 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/audioldm2/clap/open_clip/factory.py b/audioldm2/clap/open_clip/factory.py new file mode 100755 index 0000000000000000000000000000000000000000..df0f4a194c2e7328f7b7d3fe11fa6801c6cc1a7c --- /dev/null +++ b/audioldm2/clap/open_clip/factory.py @@ -0,0 +1,276 @@ +import json +import logging +import os +import re +from copy import deepcopy +from pathlib import Path + +import torch + +from .model import CLAP, convert_weights_to_fp16 +from .openai import load_openai_model +from .pretrained import get_pretrained_url, download_pretrained +from .transform import image_transform + +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = (".json",) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f"*{ext}")) + + for cf in config_files: + if os.path.basename(cf)[0] == ".": + continue # Ignore hidden files + + with open(cf, "r") as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = { + k: v + for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) + } + + +_rescan_model_configs() # initial populate of model config registry + + +def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + if skip_params: + if next(iter(state_dict.items()))[0].startswith("module"): + state_dict = {k[7:]: v for k, v in state_dict.items()} + # for k in state_dict: + # if k.startswith('transformer'): + # v = state_dict.pop(k) + # state_dict['text_branch.' + k[12:]] = v + return state_dict + + +def create_model( + amodel_name: str, + tmodel_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"), + skip_params=True, + pretrained_audio: str = "", + pretrained_text: str = "", + enable_fusion: bool = False, + fusion_type: str = "None" + # pretrained_image: bool = False, +): + amodel_name = amodel_name.replace( + "/", "-" + ) # for callers using old naming with / in ViT names + pretrained_orig = pretrained + pretrained = pretrained.lower() + if pretrained == "openai": + if amodel_name in _MODEL_CONFIGS: + logging.info(f"Loading {amodel_name} model config.") + model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) + else: + logging.error( + f"Model config for {amodel_name} not found; available models {list_models()}." + ) + raise RuntimeError(f"Model config for {amodel_name} not found.") + + logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.") + # Hard Code in model name + model_cfg["text_cfg"]["model_type"] = tmodel_name + model = load_openai_model( + "ViT-B-16", + model_cfg, + device=device, + jit=jit, + cache_dir=openai_model_cache_dir, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 + if precision == "amp" or precision == "fp32": + model = model.float() + else: + if amodel_name in _MODEL_CONFIGS: + logging.info(f"Loading {amodel_name} model config.") + model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) + else: + logging.error( + f"Model config for {amodel_name} not found; available models {list_models()}." + ) + raise RuntimeError(f"Model config for {amodel_name} not found.") + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + # if pretrained_image: + # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}): + # # pretrained weight loading for timm models set via vision_cfg + # model_cfg['vision_cfg']['timm_model_pretrained'] = True + # else: + # assert False, 'pretrained image towers currently only supported for timm models' + model_cfg["text_cfg"]["model_type"] = tmodel_name + model_cfg["enable_fusion"] = enable_fusion + model_cfg["fusion_type"] = fusion_type + model = CLAP(**model_cfg) + + if pretrained: + checkpoint_path = "" + url = get_pretrained_url(amodel_name, pretrained) + if url: + checkpoint_path = download_pretrained(url, root=openai_model_cache_dir) + elif os.path.exists(pretrained_orig): + checkpoint_path = pretrained_orig + if checkpoint_path: + logging.info( + f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})." + ) + ckpt = load_state_dict(checkpoint_path, skip_params=True) + model.load_state_dict(ckpt) + param_names = [n for n, p in model.named_parameters()] + # for n in param_names: + # print(n, "\t", "Loaded" if n in ckpt else "Unloaded") + else: + logging.warning( + f"Pretrained weights ({pretrained}) not found for model {amodel_name}." + ) + raise RuntimeError( + f"Pretrained weights ({pretrained}) not found for model {amodel_name}." + ) + + if pretrained_audio: + if amodel_name.startswith("PANN"): + if "Cnn14_mAP" in pretrained_audio: # official checkpoint + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["model"] + keys = list(audio_ckpt.keys()) + for key in keys: + if ( + "spectrogram_extractor" not in key + and "logmel_extractor" not in key + ): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key] = v + elif os.path.basename(pretrained_audio).startswith( + "PANN" + ): # checkpoint trained via HTSAT codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model"): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "finetuned" + ): # checkpoint trained via linear probe codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + else: + raise ValueError("Unknown audio checkpoint") + elif amodel_name.startswith("HTSAT"): + if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model") and ( + "spectrogram_extractor" not in key + and "logmel_extractor" not in key + ): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "HTSAT" + ): # checkpoint trained via HTSAT codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model"): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "finetuned" + ): # checkpoint trained via linear probe codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + else: + raise ValueError("Unknown audio checkpoint") + else: + raise f"this audio encoder pretrained checkpoint is not support" + + model.load_state_dict(audio_ckpt, strict=False) + logging.info( + f"Loading pretrained {amodel_name} weights ({pretrained_audio})." + ) + param_names = [n for n, p in model.named_parameters()] + for n in param_names: + print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded") + + model.to(device=device) + if precision == "fp16": + assert device.type != "cpu" + convert_weights_to_fp16(model) + + if jit: + model = torch.jit.script(model) + + return model, model_cfg + + +def create_model_and_transforms( + model_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + # pretrained_image: bool = False, +): + model = create_model( + model_name, + pretrained, + precision, + device, + jit, + force_quick_gelu=force_quick_gelu, + # pretrained_image=pretrained_image + ) + preprocess_train = image_transform(model.visual.image_size, is_train=True) + preprocess_val = image_transform(model.visual.image_size, is_train=False) + return model, preprocess_train, preprocess_val + + +def list_models(): + """enumerate available model architectures based on config files""" + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """add model config path or file and update registry""" + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() diff --git a/audioldm2/clap/open_clip/feature_fusion.py b/audioldm2/clap/open_clip/feature_fusion.py new file mode 100755 index 0000000000000000000000000000000000000000..dbe4e170e05894c12ebdc36ba1dc1de65e441b89 --- /dev/null +++ b/audioldm2/clap/open_clip/feature_fusion.py @@ -0,0 +1,192 @@ +""" +Feature Fusion for Varible-Length Data Processing +AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py +According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 +""" + +import torch +import torch.nn as nn + + +class DAF(nn.Module): + """ + 直接相加 DirectAddFuse + """ + + def __init__(self): + super(DAF, self).__init__() + + def forward(self, x, residual): + return x + residual + + +class iAFF(nn.Module): + """ + 多特征融合 iAFF + """ + + def __init__(self, channels=64, r=4, type="2D"): + super(iAFF, self).__init__() + inter_channels = int(channels // r) + + if type == "1D": + # 本地注意力 + self.local_att = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + + # 全局注意力 + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + + # 第二次本地注意力 + self.local_att2 = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + # 第二次全局注意力 + self.global_att2 = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + elif type == "2D": + # 本地注意力 + self.local_att = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + # 全局注意力 + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + # 第二次本地注意力 + self.local_att2 = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + # 第二次全局注意力 + self.global_att2 = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + else: + raise f"the type is not supported" + + self.sigmoid = nn.Sigmoid() + + def forward(self, x, residual): + flag = False + xa = x + residual + if xa.size(0) == 1: + xa = torch.cat([xa, xa], dim=0) + flag = True + xl = self.local_att(xa) + xg = self.global_att(xa) + xlg = xl + xg + wei = self.sigmoid(xlg) + xi = x * wei + residual * (1 - wei) + + xl2 = self.local_att2(xi) + xg2 = self.global_att(xi) + xlg2 = xl2 + xg2 + wei2 = self.sigmoid(xlg2) + xo = x * wei2 + residual * (1 - wei2) + if flag: + xo = xo[0].unsqueeze(0) + return xo + + +class AFF(nn.Module): + """ + 多特征融合 AFF + """ + + def __init__(self, channels=64, r=4, type="2D"): + super(AFF, self).__init__() + inter_channels = int(channels // r) + + if type == "1D": + self.local_att = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + elif type == "2D": + self.local_att = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + else: + raise f"the type is not supported." + + self.sigmoid = nn.Sigmoid() + + def forward(self, x, residual): + flag = False + xa = x + residual + if xa.size(0) == 1: + xa = torch.cat([xa, xa], dim=0) + flag = True + xl = self.local_att(xa) + xg = self.global_att(xa) + xlg = xl + xg + wei = self.sigmoid(xlg) + xo = 2 * x * wei + 2 * residual * (1 - wei) + if flag: + xo = xo[0].unsqueeze(0) + return xo diff --git a/audioldm2/clap/open_clip/htsat.py b/audioldm2/clap/open_clip/htsat.py new file mode 100755 index 0000000000000000000000000000000000000000..8bf4fceea2dfef953522c14a3a39a417658f2257 --- /dev/null +++ b/audioldm2/clap/open_clip/htsat.py @@ -0,0 +1,1304 @@ +# Ke Chen +# knutchen@ucsd.edu +# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION +# Some layers designed on the model +# below codes are based and referred from https://github.com/microsoft/Swin-Transformer +# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf + +import torch +import torch.nn as nn +from itertools import repeat +import collections.abc +import math +import warnings + +from torch.nn.init import _calculate_fan_in_and_fan_out +import torch.utils.checkpoint as checkpoint + +import random + +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from itertools import repeat +from .utils import do_mixup, interpolate + +from .feature_fusion import iAFF, AFF, DAF + + +# from PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + patch_stride=16, + enable_fusion=False, + fusion_type="None", + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patch_stride = to_2tuple(patch_stride) + self.img_size = img_size + self.patch_size = patch_size + self.patch_stride = patch_stride + self.grid_size = ( + img_size[0] // patch_stride[0], + img_size[1] // patch_stride[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + padding = ( + (patch_size[0] - patch_stride[0]) // 2, + (patch_size[1] - patch_stride[1]) // 2, + ) + + if (self.enable_fusion) and (self.fusion_type == "channel_map"): + self.proj = nn.Conv2d( + in_chans * 4, + embed_dim, + kernel_size=patch_size, + stride=patch_stride, + padding=padding, + ) + else: + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_stride, + padding=padding, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + self.mel_conv2d = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=(patch_size[0], patch_size[1] * 3), + stride=(patch_stride[0], patch_stride[1] * 3), + padding=padding, + ) + if self.fusion_type == "daf_2d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_2d": + self.fusion_model = AFF(channels=embed_dim, type="2D") + elif self.fusion_type == "iaff_2d": + self.fusion_model = iAFF(channels=embed_dim, type="2D") + + def forward(self, x, longer_idx=None): + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + global_x = x[:, 0:1, :, :] + + # global processing + B, C, H, W = global_x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + global_x = self.proj(global_x) + TW = global_x.size(-1) + if len(longer_idx) > 0: + # local processing + local_x = x[longer_idx, 1:, :, :].contiguous() + B, C, H, W = local_x.shape + local_x = local_x.view(B * C, 1, H, W) + local_x = self.mel_conv2d(local_x) + local_x = local_x.view( + B, C, local_x.size(1), local_x.size(2), local_x.size(3) + ) + local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3) + TB, TC, TH, _ = local_x.size() + if local_x.size(-1) < TW: + local_x = torch.cat( + [ + local_x, + torch.zeros( + (TB, TC, TH, TW - local_x.size(-1)), + device=global_x.device, + ), + ], + dim=-1, + ) + else: + local_x = local_x[:, :, :, :TW] + + global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x) + x = global_x + else: + B, C, H, W = x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}" + + +# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model +class SwinTransformerBlock(nn.Module): + r"""Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + norm_before_mlp="ln", + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.norm_before_mlp = norm_before_mlp + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + if self.norm_before_mlp == "ln": + self.norm2 = nn.LayerNorm(dim) + elif self.norm_before_mlp == "bn": + self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose( + 1, 2 + ) + else: + raise NotImplementedError + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill( + attn_mask != 0, float(-100.0) + ).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + # pdb.set_trace() + H, W = self.input_resolution + # print("H: ", H) + # print("W: ", W) + # pdb.set_trace() + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) + else: + shifted_x = x + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows, attn = self.attn( + x_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x, attn + + def extra_repr(self): + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + ) + + +class PatchMerging(nn.Module): + r"""Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self): + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + norm_before_mlp="ln", + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + norm_before_mlp=norm_before_mlp, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer + ) + else: + self.downsample = None + + def forward(self, x): + attns = [] + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x, attn = blk(x) + if not self.training: + attns.append(attn.unsqueeze(0)) + if self.downsample is not None: + x = self.downsample(x) + if not self.training: + attn = torch.cat(attns, dim=0) + attn = torch.mean(attn, dim=0) + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +# The Core of HTSAT +class HTSAT_Swin_Transformer(nn.Module): + r"""HTSAT based on the Swin Transformer + Args: + spec_size (int | tuple(int)): Input Spectrogram size. Default 256 + patch_size (int | tuple(int)): Patch size. Default: 4 + path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4 + in_chans (int): Number of input image channels. Default: 1 (mono) + num_classes (int): Number of classes for classification head. Default: 527 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 8 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + config (module): The configuration Module from config.py + """ + + def __init__( + self, + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + in_chans=1, + num_classes=527, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + norm_before_mlp="ln", + config=None, + enable_fusion=False, + fusion_type="None", + **kwargs, + ): + super(HTSAT_Swin_Transformer, self).__init__() + + self.config = config + self.spec_size = spec_size + self.patch_stride = patch_stride + self.patch_size = patch_size + self.window_size = window_size + self.embed_dim = embed_dim + self.depths = depths + self.ape = ape + self.in_chans = in_chans + self.num_classes = num_classes + self.num_heads = num_heads + self.num_layers = len(self.depths) + self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1)) + + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + + self.qkv_bias = qkv_bias + self.qk_scale = None + + self.patch_norm = patch_norm + self.norm_layer = norm_layer if self.patch_norm else None + self.norm_before_mlp = norm_before_mlp + self.mlp_ratio = mlp_ratio + + self.use_checkpoint = use_checkpoint + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # process mel-spec ; used only once + self.freq_ratio = self.spec_size // self.config.mel_bins + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + self.interpolate_ratio = 32 # Downsampled ratio + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=config.window_size, + hop_length=config.hop_size, + win_length=config.window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=config.sample_rate, + n_fft=config.window_size, + n_mels=config.mel_bins, + fmin=config.fmin, + fmax=config.fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) # 2 2 + self.bn0 = nn.BatchNorm2d(self.config.mel_bins) + + # split spctrogram into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=self.spec_size, + patch_size=self.patch_size, + in_chans=self.in_chans, + embed_dim=self.embed_dim, + norm_layer=self.norm_layer, + patch_stride=patch_stride, + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.grid_size + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dim) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=self.drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(self.embed_dim * 2**i_layer), + input_resolution=( + patches_resolution[0] // (2**i_layer), + patches_resolution[1] // (2**i_layer), + ), + depth=self.depths[i_layer], + num_heads=self.num_heads[i_layer], + window_size=self.window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + qk_scale=self.qk_scale, + drop=self.drop_rate, + attn_drop=self.attn_drop_rate, + drop_path=dpr[ + sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1]) + ], + norm_layer=self.norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + norm_before_mlp=self.norm_before_mlp, + ) + self.layers.append(layer) + + self.norm = self.norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.maxpool = nn.AdaptiveMaxPool1d(1) + + SF = ( + self.spec_size + // (2 ** (len(self.depths) - 1)) + // self.patch_stride[0] + // self.freq_ratio + ) + self.tscam_conv = nn.Conv2d( + in_channels=self.num_features, + out_channels=self.num_classes, + kernel_size=(SF, 3), + padding=(0, 1), + ) + self.head = nn.Linear(num_classes, num_classes) + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"] + ): + self.mel_conv1d = nn.Sequential( + nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), + nn.BatchNorm1d(64), + ) + if self.fusion_type == "daf_1d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_1d": + self.fusion_model = AFF(channels=64, type="1D") + elif self.fusion_type == "iaff_1d": + self.fusion_model = iAFF(channels=64, type="1D") + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"absolute_pos_embed"} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"relative_position_bias_table"} + + def forward_features(self, x, longer_idx=None): + # A deprecated optimization for using a hierarchical output from different blocks + + frames_num = x.shape[2] + x = self.patch_embed(x, longer_idx=longer_idx) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + for i, layer in enumerate(self.layers): + x, attn = layer(x) + # for x + x = self.norm(x) + B, N, C = x.shape + SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] + ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1] + x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST) + B, C, F, T = x.shape + # group 2D CNN + c_freq_bin = F // self.freq_ratio + x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T) + x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1) + # get latent_output + fine_grained_latent_output = torch.mean(x, dim=2) + fine_grained_latent_output = interpolate( + fine_grained_latent_output.permute(0, 2, 1).contiguous(), + 8 * self.patch_stride[1], + ) + + latent_output = self.avgpool(torch.flatten(x, 2)) + latent_output = torch.flatten(latent_output, 1) + + # display the attention map, if needed + + x = self.tscam_conv(x) + x = torch.flatten(x, 2) # B, C, T + + fpx = interpolate( + torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1] + ) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + output_dict = { + "framewise_output": fpx, # already sigmoided + "clipwise_output": torch.sigmoid(x), + "fine_grained_embedding": fine_grained_latent_output, + "embedding": latent_output, + } + + return output_dict + + def crop_wav(self, x, crop_size, spe_pos=None): + time_steps = x.shape[2] + tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device) + for i in range(len(x)): + if spe_pos is None: + crop_pos = random.randint(0, time_steps - crop_size - 1) + else: + crop_pos = spe_pos + tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :] + return tx + + # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model + def reshape_wav2img(self, x): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert ( + T <= target_T and F <= target_F + ), "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate( + x, (target_T, x.shape[3]), mode="bicubic", align_corners=True + ) + if F < target_F: + x = nn.functional.interpolate( + x, (x.shape[2], target_F), mode="bicubic", align_corners=True + ) + x = x.permute(0, 1, 3, 2).contiguous() + x = x.reshape( + x.shape[0], + x.shape[1], + x.shape[2], + self.freq_ratio, + x.shape[3] // self.freq_ratio, + ) + # print(x.shape) + x = x.permute(0, 1, 3, 2, 4).contiguous() + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4]) + return x + + # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model + def repeat_wat2img(self, x, cur_pos): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert ( + T <= target_T and F <= target_F + ), "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate( + x, (target_T, x.shape[3]), mode="bicubic", align_corners=True + ) + if F < target_F: + x = nn.functional.interpolate( + x, (x.shape[2], target_F), mode="bicubic", align_corners=True + ) + x = x.permute(0, 1, 3, 2).contiguous() # B C F T + x = x[:, :, :, cur_pos : cur_pos + self.spec_size] + x = x.repeat(repeats=(1, 1, 4, 1)) + return x + + def forward( + self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None + ): # out_feat_keys: List[str] = None): + if self.enable_fusion and x["longer"].sum() == 0: + # if no audio is longer than 10s, then randomly select one audio to be longer + x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True + + if not self.enable_fusion: + x = x["waveform"].to(device=device, non_blocking=True) + x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + if self.training: + x = self.spec_augmenter(x) + + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x) + else: + longer_list = x["longer"].to(device=device, non_blocking=True) + x = x["mel_fusion"].to(device=device, non_blocking=True) + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + longer_list_idx = torch.where(longer_list)[0] + if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]: + new_x = x[:, 0:1, :, :].clone().contiguous() + if len(longer_list_idx) > 0: + # local processing + fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous() + FB, FC, FT, FF = fusion_x_local.size() + fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) + fusion_x_local = torch.permute( + fusion_x_local, (0, 2, 1) + ).contiguous() + fusion_x_local = self.mel_conv1d(fusion_x_local) + fusion_x_local = fusion_x_local.view( + FB, FC, FF, fusion_x_local.size(-1) + ) + fusion_x_local = ( + torch.permute(fusion_x_local, (0, 2, 1, 3)) + .contiguous() + .flatten(2) + ) + if fusion_x_local.size(-1) < FT: + fusion_x_local = torch.cat( + [ + fusion_x_local, + torch.zeros( + (FB, FF, FT - fusion_x_local.size(-1)), + device=device, + ), + ], + dim=-1, + ) + else: + fusion_x_local = fusion_x_local[:, :, :FT] + # 1D fusion + new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous() + new_x[longer_list_idx] = self.fusion_model( + new_x[longer_list_idx], fusion_x_local + ) + x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :] + else: + x = new_x + + elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]: + x = x # no change + + if self.training: + x = self.spec_augmenter(x) + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x, longer_idx=longer_list_idx) + + # if infer_mode: + # # in infer mode. we need to handle different length audio input + # frame_num = x.shape[2] + # target_T = int(self.spec_size * self.freq_ratio) + # repeat_ratio = math.floor(target_T / frame_num) + # x = x.repeat(repeats=(1,1,repeat_ratio,1)) + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # else: + # if x.shape[2] > self.freq_ratio * self.spec_size: + # if self.training: + # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size) + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # else: + # # Change: Hard code here + # overlap_size = (x.shape[2] - 1) // 4 + # output_dicts = [] + # crop_size = (x.shape[2] - 1) // 2 + # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size): + # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos) + # tx = self.reshape_wav2img(tx) + # output_dicts.append(self.forward_features(tx)) + # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device) + # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device) + # for d in output_dicts: + # clipwise_output += d["clipwise_output"] + # framewise_output += d["framewise_output"] + # clipwise_output = clipwise_output / len(output_dicts) + # framewise_output = framewise_output / len(output_dicts) + # output_dict = { + # 'framewise_output': framewise_output, + # 'clipwise_output': clipwise_output + # } + # else: # this part is typically used, and most easy one + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # x = self.head(x) + + # We process the data in the dataloader part, in that here we only consider the input_T < fixed_T + + return output_dict + + +def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"): + try: + assert audio_cfg.model_name in [ + "tiny", + "base", + "large", + ], "model name for HTS-AT is wrong!" + if audio_cfg.model_name == "tiny": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + num_classes=audio_cfg.class_num, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + config=audio_cfg, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + elif audio_cfg.model_name == "base": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + num_classes=audio_cfg.class_num, + embed_dim=128, + depths=[2, 2, 12, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + config=audio_cfg, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + elif audio_cfg.model_name == "large": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + num_classes=audio_cfg.class_num, + embed_dim=256, + depths=[2, 2, 12, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + config=audio_cfg, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + + return model + except: + raise RuntimeError( + f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough." + ) diff --git a/audioldm2/clap/open_clip/loss.py b/audioldm2/clap/open_clip/loss.py new file mode 100755 index 0000000000000000000000000000000000000000..37faba58f3693d0659512ab1d6e19614fbda0675 --- /dev/null +++ b/audioldm2/clap/open_clip/loss.py @@ -0,0 +1,397 @@ +import torch +import torch.distributed.nn +from torch import distributed as dist, nn as nn +from torch.nn import functional as F +import numpy as np +from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def gather_features( + audio_features, + text_features, + audio_features_mlp=None, + text_features_mlp=None, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, + use_horovod=False, + mlp_loss=False, +): + if use_horovod: + assert hvd is not None, "Please install horovod" + if gather_with_grad: + all_audio_features = hvd.allgather(audio_features) + all_text_features = hvd.allgather(text_features) + if mlp_loss: + all_audio_features_mlp = hvd.allgather(audio_features_mlp) + all_text_features_mlp = hvd.allgather(text_features_mlp) + else: + with torch.no_grad(): + all_audio_features = hvd.allgather(audio_features) + all_text_features = hvd.allgather(text_features) + if mlp_loss: + all_audio_features_mlp = hvd.allgather(audio_features_mlp) + all_text_features_mlp = hvd.allgather(text_features_mlp) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_audio_features = list( + all_audio_features.chunk(world_size, dim=0) + ) + gathered_text_features = list( + all_text_features.chunk(world_size, dim=0) + ) + gathered_audio_features[rank] = audio_features + gathered_text_features[rank] = text_features + all_audio_features = torch.cat(gathered_audio_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + if mlp_loss: + gathered_audio_features_mlp = list( + all_audio_features_mlp.chunk(world_size, dim=0) + ) + gathered_text_features_mlp = list( + all_text_features_mlp.chunk(world_size, dim=0) + ) + gathered_audio_features_mlp[rank] = audio_features_mlp + gathered_text_features_mlp[rank] = text_features_mlp + all_audio_features_mlp = torch.cat( + gathered_audio_features_mlp, dim=0 + ) + all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_audio_features = torch.cat( + torch.distributed.nn.all_gather(audio_features), dim=0 + ) + all_text_features = torch.cat( + torch.distributed.nn.all_gather(text_features), dim=0 + ) + if mlp_loss: + all_audio_features_mlp = torch.cat( + torch.distributed.nn.all_gather(audio_features_mlp), dim=0 + ) + all_text_features_mlp = torch.cat( + torch.distributed.nn.all_gather(text_features_mlp), dim=0 + ) + else: + gathered_audio_features = [ + torch.zeros_like(audio_features) for _ in range(world_size) + ] + gathered_text_features = [ + torch.zeros_like(text_features) for _ in range(world_size) + ] + dist.all_gather(gathered_audio_features, audio_features) + dist.all_gather(gathered_text_features, text_features) + if mlp_loss: + gathered_audio_features_mlp = [ + torch.zeros_like(audio_features_mlp) for _ in range(world_size) + ] + gathered_text_features_mlp = [ + torch.zeros_like(text_features_mlp) for _ in range(world_size) + ] + dist.all_gather(gathered_audio_features_mlp, audio_features_mlp) + dist.all_gather(gathered_text_features_mlp, text_features_mlp) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_audio_features[rank] = audio_features + gathered_text_features[rank] = text_features + if mlp_loss: + gathered_audio_features_mlp[rank] = audio_features_mlp + gathered_text_features_mlp[rank] = text_features_mlp + + all_audio_features = torch.cat(gathered_audio_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + if mlp_loss: + all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0) + all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) + if mlp_loss: + return ( + all_audio_features, + all_text_features, + all_audio_features_mlp, + all_text_features_mlp, + ) + else: + return all_audio_features, all_text_features + + +class ClipLoss(nn.Module): + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + mlp_loss=False, + weight_loss_kappa=0, + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + self.mlp_loss = mlp_loss + self.weighted_loss = bool(weight_loss_kappa != 0) + self.weight_loss_kappa = weight_loss_kappa + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def forward( + self, + audio_features, + text_features, + logit_scale_a, + logit_scale_t=None, + audio_features_mlp=None, + text_features_mlp=None, + ): + device = audio_features.device + if self.mlp_loss: + if self.world_size > 1: + ( + all_audio_features, + all_text_features, + all_audio_features_mlp, + all_text_features_mlp, + ) = gather_features( + audio_features=audio_features, + text_features=text_features, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp, + local_loss=self.local_loss, + gather_with_grad=self.gather_with_grad, + rank=self.rank, + world_size=self.world_size, + use_horovod=self.use_horovod, + mlp_loss=self.mlp_loss, + ) + if self.local_loss: + a_logits_per_audio = ( + logit_scale_a * audio_features @ all_text_features_mlp.T + ) + a_logits_per_text = ( + logit_scale_a * text_features_mlp @ all_audio_features.T + ) + t_logits_per_audio = ( + logit_scale_t * audio_features_mlp @ all_text_features.T + ) + t_logits_per_text = ( + logit_scale_t * text_features @ all_audio_features_mlp.T + ) + else: + a_logits_per_audio = ( + logit_scale_a * all_audio_features @ all_text_features_mlp.T + ) + a_logits_per_text = a_logits_per_audio.T + t_logits_per_audio = ( + logit_scale_t * all_audio_features_mlp @ all_text_features.T + ) + t_logits_per_text = t_logits_per_audio.T + else: + a_logits_per_audio = ( + logit_scale_a * audio_features @ text_features_mlp.T + ) + a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T + t_logits_per_audio = ( + logit_scale_t * audio_features_mlp @ text_features.T + ) + t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T + + # calculated ground-truth and cache if enabled + num_logits = a_logits_per_audio.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + + if not self.weighted_loss: + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels) + + F.cross_entropy(a_logits_per_text, labels) + + F.cross_entropy(t_logits_per_audio, labels) + + F.cross_entropy(t_logits_per_text, labels) + ) / 4 + else: + audio_weight = (audio_features @ audio_features.T).detach() + audio_weight = ( + torch.exp( + torch.sum(audio_weight, axis=1) + / (self.weight_loss_kappa * len(audio_weight)) + ) + ).detach() + text_weight = (text_features @ text_features.T).detach() + text_weight = ( + torch.exp( + torch.sum(text_weight, axis=1) + / (self.weight_loss_kappa * len(text_features)) + ) + ).detach() + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight) + + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight) + + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight) + + F.cross_entropy(t_logits_per_text, labels, weight=text_weight) + ) / 4 + else: + if self.world_size > 1: + all_audio_features, all_text_features = gather_features( + audio_features=audio_features, + text_features=text_features, + local_loss=self.local_loss, + gather_with_grad=self.gather_with_grad, + rank=self.rank, + world_size=self.world_size, + use_horovod=self.use_horovod, + mlp_loss=self.mlp_loss, + ) + + if self.local_loss: + logits_per_audio = ( + logit_scale_a * audio_features @ all_text_features.T + ) + logits_per_text = ( + logit_scale_a * text_features @ all_audio_features.T + ) + else: + logits_per_audio = ( + logit_scale_a * all_audio_features @ all_text_features.T + ) + logits_per_text = logits_per_audio.T + else: + logits_per_audio = logit_scale_a * audio_features @ text_features.T + logits_per_text = logit_scale_a * text_features @ audio_features.T + + # calculated ground-truth and cache if enabled + num_logits = logits_per_audio.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + if not self.weighted_loss: + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + else: + audio_weight = (all_audio_features @ all_audio_features.T).detach() + audio_weight = ( + torch.exp( + torch.sum(audio_weight, axis=1) + / (self.weight_loss_kappa * len(all_audio_features)) + ) + ).detach() + text_weight = (all_text_features @ all_text_features.T).detach() + text_weight = ( + torch.exp( + torch.sum(text_weight, axis=1) + / (self.weight_loss_kappa * len(all_text_features)) + ) + ).detach() + total_loss = ( + F.cross_entropy(logits_per_audio, labels, weight=text_weight) + + F.cross_entropy(logits_per_text, labels, weight=audio_weight) + ) / 2 + return total_loss + + +def lp_gather_features(pred, target, world_size=1, use_horovod=False): + if use_horovod: + assert hvd is not None, "Please install horovod" + with torch.no_grad(): + all_preds = hvd.allgather(pred) + all_targets = hvd.allgath(target) + else: + gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)] + gathered_targets = [torch.zeros_like(target) for _ in range(world_size)] + + dist.all_gather(gathered_preds, pred) + dist.all_gather(gathered_targets, target) + all_preds = torch.cat(gathered_preds, dim=0) + all_targets = torch.cat(gathered_targets, dim=0) + + return all_preds, all_targets + + +def get_map(pred, target): + pred = torch.sigmoid(pred).numpy() + target = target.numpy() + return np.mean(average_precision_score(target, pred, average=None)) + + +def get_acc(pred, target): + pred = torch.argmax(pred, 1).numpy() + target = torch.argmax(target, 1).numpy() + return accuracy_score(target, pred) + + +def get_mauc(pred, target): + pred = torch.sigmoid(pred).numpy() + target = target.numpy() + return np.mean(roc_auc_score(target, pred, average=None)) + + +class LPMetrics(object): + def __init__(self, metric_names=["map", "acc", "mauc"]): + self.metrics = [] + for name in metric_names: + self.metrics.append(self.get_metric(name)) + self.metric_names = metric_names + + def get_metric(self, name): + if name == "map": + return get_map + elif name == "acc": + return get_acc + elif name == "mauc": + return get_mauc + else: + raise ValueError(f"the metric should be at least one of [map, acc, mauc]") + + def evaluate_mertics(self, pred, target): + metric_dict = {} + for i in range(len(self.metric_names)): + metric_dict[self.metric_names[i]] = self.metrics[i](pred, target) + return metric_dict + + +def calc_celoss(pred, target): + target = torch.argmax(target, 1).long() + return nn.CrossEntropyLoss()(pred, target) + + +class LPLoss(nn.Module): + def __init__(self, loss_name): + super().__init__() + if loss_name == "bce": + self.loss_func = nn.BCEWithLogitsLoss() + elif loss_name == "ce": + self.loss_func = calc_celoss + elif loss_name == "mse": + self.loss_func = nn.MSELoss() + else: + raise ValueError(f"the loss func should be at least one of [bce, ce, mse]") + + def forward(self, pred, target): + loss = self.loss_func(pred, target) + return loss diff --git a/audioldm2/clap/open_clip/model.py b/audioldm2/clap/open_clip/model.py new file mode 100755 index 0000000000000000000000000000000000000000..130fb582d016868d478e2d10e90d7fc0e7999078 --- /dev/null +++ b/audioldm2/clap/open_clip/model.py @@ -0,0 +1,931 @@ +""" CLAP Model + +Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +Adapted to the Audio Task. +""" + +from collections import OrderedDict +from dataclasses import dataclass +from typing import Tuple, Union, Callable, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +import logging +from .utils import freeze_batch_norm_2d + +from .pann_model import create_pann_model +from .htsat import create_htsat_model +from transformers import BertModel, RobertaModel, BartModel, RobertaConfig + + +class MLPLayers(nn.Module): + def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1): + super(MLPLayers, self).__init__() + self.nonlin = nonlin + self.dropout = dropout + + sequence = [] + for u0, u1 in zip(units[:-1], units[1:]): + sequence.append(nn.Linear(u0, u1)) + sequence.append(self.nonlin) + sequence.append(nn.Dropout(self.dropout)) + sequence = sequence[:-2] + + self.sequential = nn.Sequential(*sequence) + + def forward(self, X): + X = self.sequential(X) + return X + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ( + "0", + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False, + ), + ), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__( + self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None + ): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5 + ) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( + 2, 0, 1 + ) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] + ), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + def stem(self, x): + for conv, bn in [ + (self.conv1, self.bn1), + (self.conv2, self.bn2), + (self.conv3, self.bn3), + ]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(d_model * 4, d_model)), + ] + ) + ) + self.ln_2 = LayerNorm(d_model) + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock(width, heads, act_layer=act_layer) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + + +class VisualTransformer(nn.Module): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + act_layer: Callable = nn.GELU, + ): + super().__init__() + self.image_size = image_size + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn((image_size // patch_size) ** 2 + 1, width) + ) + self.ln_pre = LayerNorm(width) + + self.text_branch = Transformer(width, layers, heads, act_layer=act_layer) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ + self.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_branch(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +@dataclass +class CLAPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + timm_model_name: str = ( + None # a valid model name overrides layers, width, patch_size + ) + timm_model_pretrained: bool = ( + False # use (imagenet) pretrained weights for named model + ) + timm_pool: str = ( + "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + ) + timm_proj: str = ( + "linear" # linear projection for timm model output ('linear', 'mlp', '') + ) + + +# Audio Config Class +@dataclass +class CLAPAudioCfp: + model_type: str = "PANN" + model_name: str = "Cnn14" + sample_rate: int = 48000 + # Param + audio_length: int = 1024 + window_size: int = 1024 + hop_size: int = 1024 + fmin: int = 50 + fmax: int = 14000 + class_num: int = 527 + mel_bins: int = 64 + clip_samples: int = 480000 + + +@dataclass +class CLAPTextCfg: + context_length: int + vocab_size: int + width: int + heads: int + layers: int + model_type: str + + +class CLAP(nn.Module): + def __init__( + self, + embed_dim: int, + audio_cfg: CLAPAudioCfp, + text_cfg: CLAPTextCfg, + quick_gelu: bool = False, + enable_fusion: bool = False, + fusion_type: str = "None", + joint_embed_shape: int = 512, + mlp_act: str = "relu", + ): + super().__init__() + if isinstance(audio_cfg, dict): + audio_cfg = CLAPAudioCfp(**audio_cfg) + if isinstance(text_cfg, dict): + text_cfg = CLAPTextCfg(**text_cfg) + + self.audio_cfg = audio_cfg + self.text_cfg = text_cfg + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + self.joint_embed_shape = joint_embed_shape + self.mlp_act = mlp_act + + self.context_length = text_cfg.context_length + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if mlp_act == "relu": + mlp_act_layer = nn.ReLU() + elif mlp_act == "gelu": + mlp_act_layer = nn.GELU() + else: + raise NotImplementedError + + # audio branch + # audio branch parameters + if audio_cfg.model_type == "PANN": + self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type) + elif audio_cfg.model_type == "HTSAT": + self.audio_branch = create_htsat_model( + audio_cfg, enable_fusion, fusion_type + ) + else: + logging.error(f"Model config for {audio_cfg.model_type} not found") + raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.") + + # text branch + # text branch parameters + if text_cfg.model_type == "transformer": + self.text_branch = Transformer( + width=text_cfg.width, + layers=text_cfg.layers, + heads=text_cfg.heads, + act_layer=act_layer, + ) + self.vocab_size = text_cfg.vocab_size + self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, text_cfg.width) + ) + self.ln_final = LayerNorm(text_cfg.width) + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(text_cfg.width, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "bert": + self.text_branch = BertModel.from_pretrained("bert-base-uncased") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "roberta": + self.text_branch = RobertaModel( + RobertaConfig.from_pretrained("roberta-base") + ) + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "bart": + self.text_branch = BartModel.from_pretrained("facebook/bart-base") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + else: + logging.error(f"Model config for {text_cfg.model_type} not found") + raise RuntimeError(f"Model config for {text_cfg.model_type} not found.") + self.text_branch_type = text_cfg.model_type + # text branch parameters + + # audio branch parameters + self.audio_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + + # below here is text branch parameters + + # ============================================================================================================ + self.audio_projection = nn.Sequential( + nn.Linear(embed_dim, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + + self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False) + + self.init_text_branch_parameters() + + def init_text_branch_parameters(self): + if self.text_branch_type == "transformer": + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + proj_std = (self.text_branch.width**-0.5) * ( + (2 * self.text_branch.layers) ** -0.5 + ) + attn_std = self.text_branch.width**-0.5 + fc_std = (2 * self.text_branch.width) ** -0.5 + for block in self.text_branch.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + if self.text_branch_type == "bert" or self.text_branch_type == "roberta": + self.text_branch.embeddings.word_embeddings.weight.shape[-1] + elif self.text_branch_type == "bart": + self.text_branch.shared.weight.shape[-1] + else: + self.text_branch.width + nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07)) + nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07)) + + # deprecated + # if hasattr(self.visual, 'init_parameters'): + # self.visual.init_parameters() + + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def encode_audio(self, audio, device): + return self.audio_branch( + audio, mixup_lambda=None, device=device + ) # mix lambda needs to add + + # def list_of_dict_of_tensor2dict_of_tensor(self, x, device): + # tmp = {} + # for k in x[0].keys(): + # tmp[k] = [] + # for i in range(len(x)): + # tmp[k].append(x[i][k][:77]) + # for k in x[0].keys(): + # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True) + # return tmp + + def encode_text(self, text, device): + if self.text_branch_type == "transformer": + text = text.to(device=device, non_blocking=True) + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_branch(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)]) + elif self.text_branch_type == "bert": + # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device) + # text = BatchEncoding(text) + x = self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + token_type_ids=text["token_type_ids"].to( + device=device, non_blocking=True + ), + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "roberta": + x = self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "bart": + x = torch.mean( + self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + )["encoder_last_hidden_state"], + axis=1, + ) + x = self.text_projection(x) + else: + logging.error(f"Model type {self.text_branch_type} not found") + raise RuntimeError(f"Model type {self.text_branch_type} not found.") + return x + + def forward(self, audio, text, device=None): + """Forward audio and text into the CLAP + + Parameters + ---------- + audio: torch.Tensor (batch_size, audio_length) + the time-domain audio input / the batch of mel_spec and longer list. + text: torch.Tensor () // need to add + the text token input + """ + if device is None: + if audio is not None: + device = audio.device + elif text is not None: + device = text.device + if audio is None and text is None: + # a hack to get the logit scale + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + elif audio is None: + return self.encode_text(text, device=device) + elif text is None: + return self.audio_projection( + self.encode_audio(audio, device=device)["embedding"] + ) + audio_features = self.audio_projection( + self.encode_audio(audio, device=device)["embedding"] + ) + audio_features = F.normalize(audio_features, dim=-1) + + text_features = self.encode_text(text, device=device) + # print("text_features", text_features) + # print("text_features.shape", text_features.shape) + # print("text_features.type", type(text_features)) + text_features = F.normalize(text_features, dim=-1) + + audio_features_mlp = self.audio_transform(audio_features) + text_features_mlp = self.text_transform(text_features) + # Four outputs: audio features (basic & MLP), text features (basic & MLP) + return ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + self.logit_scale_a.exp(), + self.logit_scale_t.exp(), + ) + + def get_logit_scale(self): + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + + def get_text_embedding(self, data): + """Get the text embedding from the model + + Parameters + ---------- + data: torch.Tensor + a tensor of text embedding + + Returns + ---------- + text_embed: torch.Tensor + a tensor of text_embeds (N, D) + + """ + device = next(self.parameters()).device + for k in data: + data[k] = data[k].to(device) + text_embeds = self.encode_text(data, device=device) + text_embeds = F.normalize(text_embeds, dim=-1) + + return text_embeds + + def get_audio_embedding(self, data): + """Get the audio embedding from the model + + Parameters + ---------- + data: a list of dict + the audio input dict list from 'get_audio_feature' method + + Returns + ---------- + audio_embed: torch.Tensor + a tensor of audio_embeds (N, D) + + """ + device = next(self.parameters()).device + # input_dict = {} + # keys = data[0].keys() + # for k in keys: + # input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to( + # device + # ) + audio_embeds = self.audio_projection( + self.encode_audio(data, device=device)["embedding"] + ) + audio_embeds = F.normalize(audio_embeds, dim=-1) + + return audio_embeds + + def audio_infer(self, audio, hopsize=None, device=None): + """Forward one audio and produce the audio embedding + + Parameters + ---------- + audio: (audio_length) + the time-domain audio input, notice that it must be only one input + hopsize: int + the overlap hopsize as the sliding window + + Returns + ---------- + output_dict: { + key: [n, (embedding_shape)] if "HTS-AT" + or + key: [(embedding_shape)] if "PANN" + } + the list of key values of the audio branch + + """ + + assert not self.training, "the inference mode must be run at eval stage" + output_dict = {} + # PANN + if self.audio_cfg.model_type == "PANN": + audio_input = audio.unsqueeze(dim=0) + output_dict[key] = self.encode_audio(audio_input, device=device)[ + key + ].squeeze(dim=0) + elif self.audio_cfg.model_type == "HTSAT": + # repeat + audio_len = len(audio) + k = self.audio_cfg.clip_samples // audio_len + if k > 1: + audio = audio.repeat(k) + audio_len = len(audio) + + if hopsize is None: + hopsize = min(hopsize, audio_len) + + if audio_len > self.audio_cfg.clip_samples: + audio_input = [ + audio[pos : pos + self.audio_cfg.clip_samples].clone() + for pos in range( + 0, audio_len - self.audio_cfg.clip_samples, hopsize + ) + ] + audio_input.append(audio[-self.audio_cfg.clip_samples :].clone()) + audio_input = torch.stack(audio_input) + output_dict[key] = self.encode_audio(audio_input, device=device)[key] + else: + audio_input = audio.unsqueeze(dim=0) + output_dict[key] = self.encode_audio(audio_input, device=device)[ + key + ].squeeze(dim=0) + + return output_dict + + +def convert_weights_to_fp16(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [ + *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], + "in_proj_bias", + "bias_k", + "bias_v", + ]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +# Ignore the state dict of the vision part +def build_model_from_openai_state_dict( + state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None" +): + embed_dim = model_cfg["embed_dim"] + audio_cfg = model_cfg["audio_cfg"] + text_cfg = model_cfg["text_cfg"] + state_dict["positional_embedding"].shape[0] + state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_width // 64 + transformer_layers = len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith(f"transformer.resblocks") + ) + ) + + audio_cfg = CLAPAudioCfp(**audio_cfg) + text_cfg = CLAPTextCfg(**text_cfg) + + model = CLAP( + embed_dim, + audio_cfg=audio_cfg, + text_cfg=text_cfg, + quick_gelu=True, # OpenAI models were trained with QuickGELU + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + state_dict["logit_scale_a"] = state_dict["logit_scale"] + state_dict["logit_scale_t"] = state_dict["logit_scale"] + pop_keys = list(state_dict.keys())[::] + # pop the visual branch saved weights + for key in pop_keys: + if key.startswith("visual."): + state_dict.pop(key, None) + + for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + # not use fp16 + # convert_weights_to_fp16(model) + model.load_state_dict(state_dict, strict=False) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device("cpu")): + model.eval() + audio_length = model.audio_cfg.audio_length + example_audio = torch.ones((batch_size, audio_length), device=device) + example_text = torch.zeros( + (batch_size, model.context_length), dtype=torch.int, device=device + ) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_audio, example_text), + encode_text=(example_text,), + encode_image=(example_audio,), + ), + ) + model.audio_cfg.audio_length = audio_length # Question: what does this do? + return model diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-base.json b/audioldm2/clap/open_clip/model_configs/HTSAT-base.json new file mode 100755 index 0000000000000000000000000000000000000000..6cef625a89daf4431f1c9f72e10bc9640eef2ba8 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/HTSAT-base.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 1024, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "base" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-large.json b/audioldm2/clap/open_clip/model_configs/HTSAT-large.json new file mode 100755 index 0000000000000000000000000000000000000000..699cdb1b16855582606551e4196b24aba2ffd871 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/HTSAT-large.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "large" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json new file mode 100755 index 0000000000000000000000000000000000000000..73e42990fe8361a0df502e7f93d29f19f58c9ecb --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 768, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1536, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "tiny" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json new file mode 100755 index 0000000000000000000000000000000000000000..a6e7821163d9afa81c27345a1e472475b92af169 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 768, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "tiny" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-10.json b/audioldm2/clap/open_clip/model_configs/PANN-10.json new file mode 100755 index 0000000000000000000000000000000000000000..954ddf62921aed7dde9c37ffffec98a2e96a4ee7 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/PANN-10.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 1024, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn10" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json new file mode 100755 index 0000000000000000000000000000000000000000..b7989bc0cd95d0d39049b7524eba508b3e386439 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 18000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json new file mode 100755 index 0000000000000000000000000000000000000000..56bdb56bedc304ffa52d8bf5988cea2c1d82d14e --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 960000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 360, + "fmin": 50, + "fmax": 8000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json b/audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json new file mode 100755 index 0000000000000000000000000000000000000000..5756e3bebc97cc985f512cb081930fee4e49bec1 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 4 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json b/audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json new file mode 100755 index 0000000000000000000000000000000000000000..5a9e7e208b661619d5e26625e849da1adda8a475 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1536, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14.json b/audioldm2/clap/open_clip/model_configs/PANN-14.json new file mode 100755 index 0000000000000000000000000000000000000000..39a5134cde1d8c50f4758377c952ef22f07bab41 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/PANN-14.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-6.json b/audioldm2/clap/open_clip/model_configs/PANN-6.json new file mode 100755 index 0000000000000000000000000000000000000000..21ebc344326de260c386ba77e0ad63cf9b04febf --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/PANN-6.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 512, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn6" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json b/audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json new file mode 100755 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/RN101.json b/audioldm2/clap/open_clip/model_configs/RN101.json new file mode 100755 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json b/audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json new file mode 100755 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/audioldm2/clap/open_clip/model_configs/RN50.json b/audioldm2/clap/open_clip/model_configs/RN50.json new file mode 100755 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/RN50x16.json b/audioldm2/clap/open_clip/model_configs/RN50x16.json new file mode 100755 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/RN50x4.json b/audioldm2/clap/open_clip/model_configs/RN50x4.json new file mode 100755 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/ViT-B-16.json b/audioldm2/clap/open_clip/model_configs/ViT-B-16.json new file mode 100755 index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json b/audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json new file mode 100755 index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/ViT-B-32.json b/audioldm2/clap/open_clip/model_configs/ViT-B-32.json new file mode 100755 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/ViT-L-14.json b/audioldm2/clap/open_clip/model_configs/ViT-L-14.json new file mode 100755 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/audioldm2/clap/open_clip/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/openai.py b/audioldm2/clap/open_clip/openai.py new file mode 100755 index 0000000000000000000000000000000000000000..3f4eb8b55fe960e1792b3da804b60b3d8f70fe26 --- /dev/null +++ b/audioldm2/clap/open_clip/openai.py @@ -0,0 +1,156 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +import warnings +from typing import Union, List + +import torch + +from .model import build_model_from_openai_state_dict +from .pretrained import ( + get_pretrained_url, + list_pretrained_tag_models, + download_pretrained, +) + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_tag_models("openai") + + +def load_openai_model( + name: str, + model_cfg, + device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + jit=True, + cache_dir=os.path.expanduser("~/.cache/clip"), + enable_fusion: bool = False, + fusion_type: str = "None", +): + """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + + Returns + ------- + model : torch.nn.Module + The CLAP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if get_pretrained_url(name, "openai"): + model_path = download_pretrained( + get_pretrained_url(name, "openai"), root=cache_dir + ) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError( + f"Model {name} not found; available models = {list_openai_models()}" + ) + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn( + f"File {model_path} is not a JIT archive. Loading as a state dict instead" + ) + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + try: + model = build_model_from_openai_state_dict( + state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type + ).to(device) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict( + sd, model_cfg, enable_fusion, fusion_type + ).to(device) + + if str(device) == "cpu": + model.float() + return model + + # patch the device names + device_holder = torch.jit.trace( + lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] + ) + device_node = [ + n + for n in device_holder.graph.findAllNodes("prim::Constant") + if "Device" in repr(n) + ][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith( + "cuda" + ): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_audio) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace( + lambda: torch.ones([]).float(), example_inputs=[] + ) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [ + 1, + 2, + ]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_audio) + patch_float(model.encode_text) + model.float() + + model.audio_branch.audio_length = model.audio_cfg.audio_length + return model diff --git a/audioldm2/clap/open_clip/pann_model.py b/audioldm2/clap/open_clip/pann_model.py new file mode 100755 index 0000000000000000000000000000000000000000..e9fab8e03cdca370c141a9e321e98d256e79fb27 --- /dev/null +++ b/audioldm2/clap/open_clip/pann_model.py @@ -0,0 +1,697 @@ +# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition +# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn +# Some layers are re-designed for CLAP +import os + +os.environ["NUMBA_CACHE_DIR"] = "/tmp/" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from .utils import do_mixup, interpolate +from .feature_fusion import iAFF, AFF, DAF + + +def init_layer(layer): + """Initialize a Linear or Convolutional layer.""" + nn.init.xavier_uniform_(layer.weight) + + if hasattr(layer, "bias"): + if layer.bias is not None: + layer.bias.data.fill_(0.0) + + +def init_bn(bn): + """Initialize a Batchnorm layer.""" + bn.bias.data.fill_(0.0) + bn.weight.data.fill_(1.0) + + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(ConvBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ) + + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ) + + self.bn1 = nn.BatchNorm2d(out_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_layer(self.conv2) + init_bn(self.bn1) + init_bn(self.bn2) + + def forward(self, input, pool_size=(2, 2), pool_type="avg"): + x = input + x = F.relu_(self.bn1(self.conv1(x))) + x = F.relu_(self.bn2(self.conv2(x))) + if pool_type == "max": + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg": + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg+max": + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception("Incorrect argument!") + + return x + + +class ConvBlock5x5(nn.Module): + def __init__(self, in_channels, out_channels): + super(ConvBlock5x5, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(5, 5), + stride=(1, 1), + padding=(2, 2), + bias=False, + ) + + self.bn1 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_bn(self.bn1) + + def forward(self, input, pool_size=(2, 2), pool_type="avg"): + x = input + x = F.relu_(self.bn1(self.conv1(x))) + if pool_type == "max": + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg": + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg+max": + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception("Incorrect argument!") + + return x + + +class AttBlock(nn.Module): + def __init__(self, n_in, n_out, activation="linear", temperature=1.0): + super(AttBlock, self).__init__() + + self.activation = activation + self.temperature = temperature + self.att = nn.Conv1d( + in_channels=n_in, + out_channels=n_out, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + self.cla = nn.Conv1d( + in_channels=n_in, + out_channels=n_out, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + + self.bn_att = nn.BatchNorm1d(n_out) + self.init_weights() + + def init_weights(self): + init_layer(self.att) + init_layer(self.cla) + init_bn(self.bn_att) + + def forward(self, x): + # x: (n_samples, n_in, n_time) + norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) + cla = self.nonlinear_transform(self.cla(x)) + x = torch.sum(norm_att * cla, dim=2) + return x, norm_att, cla + + def nonlinear_transform(self, x): + if self.activation == "linear": + return x + elif self.activation == "sigmoid": + return torch.sigmoid(x) + + +class Cnn14(nn.Module): + def __init__( + self, + sample_rate, + window_size, + hop_size, + mel_bins, + fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + super(Cnn14, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + if (self.enable_fusion) and (self.fusion_type == "channel_map"): + self.conv_block1 = ConvBlock(in_channels=4, out_channels=64) + else: + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"] + ): + self.mel_conv1d = nn.Sequential( + nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), + nn.BatchNorm1d(64), # No Relu + ) + if self.fusion_type == "daf_1d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_1d": + self.fusion_model = AFF(channels=64, type="1D") + elif self.fusion_type == "iaff_1d": + self.fusion_model = iAFF(channels=64, type="1D") + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + self.mel_conv2d = nn.Sequential( + nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + ) + + if self.fusion_type == "daf_2d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_2d": + self.fusion_model = AFF(channels=64, type="2D") + elif self.fusion_type == "iaff_2d": + self.fusion_model = iAFF(channels=64, type="2D") + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + if self.enable_fusion and input["longer"].sum() == 0: + # if no audio is longer than 10s, then randomly select one audio to be longer + input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True + + if not self.enable_fusion: + x = self.spectrogram_extractor( + input["waveform"].to(device=device, non_blocking=True) + ) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + else: + longer_list = input["longer"].to(device=device, non_blocking=True) + x = input["mel_fusion"].to(device=device, non_blocking=True) + longer_list_idx = torch.where(longer_list)[0] + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]: + new_x = x[:, 0:1, :, :].clone().contiguous() + # local processing + if len(longer_list_idx) > 0: + fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous() + FB, FC, FT, FF = fusion_x_local.size() + fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) + fusion_x_local = torch.permute( + fusion_x_local, (0, 2, 1) + ).contiguous() + fusion_x_local = self.mel_conv1d(fusion_x_local) + fusion_x_local = fusion_x_local.view( + FB, FC, FF, fusion_x_local.size(-1) + ) + fusion_x_local = ( + torch.permute(fusion_x_local, (0, 2, 1, 3)) + .contiguous() + .flatten(2) + ) + if fusion_x_local.size(-1) < FT: + fusion_x_local = torch.cat( + [ + fusion_x_local, + torch.zeros( + (FB, FF, FT - fusion_x_local.size(-1)), + device=device, + ), + ], + dim=-1, + ) + else: + fusion_x_local = fusion_x_local[:, :, :FT] + # 1D fusion + new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous() + new_x[longer_list_idx] = self.fusion_model( + new_x[longer_list_idx], fusion_x_local + ) + x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :] + else: + x = new_x + elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]: + x = x # no change + + if self.training: + x = self.spec_augmenter(x) + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + global_x = x[:, 0:1, :, :] + + # global processing + B, C, H, W = global_x.shape + global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg") + if len(longer_list_idx) > 0: + local_x = x[longer_list_idx, 1:, :, :].contiguous() + TH = global_x.size(-2) + # local processing + B, C, H, W = local_x.shape + local_x = local_x.view(B * C, 1, H, W) + local_x = self.mel_conv2d(local_x) + local_x = local_x.view( + B, C, local_x.size(1), local_x.size(2), local_x.size(3) + ) + local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3) + TB, TC, _, TW = local_x.size() + if local_x.size(-2) < TH: + local_x = torch.cat( + [ + local_x, + torch.zeros( + (TB, TC, TH - local_x.size(-2), TW), + device=global_x.device, + ), + ], + dim=-2, + ) + else: + local_x = local_x[:, :, :TH, :] + + global_x[longer_list_idx] = self.fusion_model( + global_x[longer_list_idx], local_x + ) + x = global_x + else: + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 32) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = { + "clipwise_output": clipwise_output, + "embedding": embedding, + "fine_grained_embedding": latent_output, + } + return output_dict + + +class Cnn6(nn.Module): + def __init__( + self, + sample_rate, + window_size, + hop_size, + mel_bins, + fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + super(Cnn6, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512) + + self.fc1 = nn.Linear(512, 512, bias=True) + self.fc_audioset = nn.Linear(512, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 16) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = { + "clipwise_output": clipwise_output, + "embedding": embedding, + "fine_grained_embedding": latent_output, + } + + return output_dict + + +class Cnn10(nn.Module): + def __init__( + self, + sample_rate, + window_size, + hop_size, + mel_bins, + fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + super(Cnn10, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + + self.fc1 = nn.Linear(1024, 1024, bias=True) + self.fc_audioset = nn.Linear(1024, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 32) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = { + "clipwise_output": clipwise_output, + "embedding": embedding, + "fine_grained_embedding": latent_output, + } + + return output_dict + + +def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"): + try: + ModelProto = eval(audio_cfg.model_name) + model = ModelProto( + sample_rate=audio_cfg.sample_rate, + window_size=audio_cfg.window_size, + hop_size=audio_cfg.hop_size, + mel_bins=audio_cfg.mel_bins, + fmin=audio_cfg.fmin, + fmax=audio_cfg.fmax, + classes_num=audio_cfg.class_num, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + return model + except: + raise RuntimeError( + f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough." + ) diff --git a/audioldm2/clap/open_clip/pretrained.py b/audioldm2/clap/open_clip/pretrained.py new file mode 100755 index 0000000000000000000000000000000000000000..e211d8b5b59320a599e62605f1dee6199f317253 --- /dev/null +++ b/audioldm2/clap/open_clip/pretrained.py @@ -0,0 +1,167 @@ +import hashlib +import os +import urllib +import warnings + +from tqdm import tqdm + +_RN50 = dict( + openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", +) + +_RN50_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", +) + +_RN101 = dict( + openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", +) + +_RN101_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", +) + +_RN50x4 = dict( + openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", +) + +_RN50x16 = dict( + openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", +) + +_RN50x64 = dict( + openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", +) + +_VITB32 = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB32_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB16 = dict( + openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", +) + +_VITL14 = dict( + openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", +) + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-L-14": _VITL14, +} + + +def list_pretrained(as_str: bool = False): + """returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [ + ":".join([k, t]) if as_str else (k, t) + for k in _PRETRAINED.keys() + for t in _PRETRAINED[k].keys() + ] + + +def list_pretrained_tag_models(tag: str): + """return all models having the specified pretrain tag""" + models = [] + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_model_tags(model: str): + """return all pretrain tags for the specified model architecture""" + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def get_pretrained_url(model: str, tag: str): + if model not in _PRETRAINED: + return "" + model_pretrained = _PRETRAINED[model] + if tag not in model_pretrained: + return "" + return model_pretrained[tag] + + +def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + if "openaipublic" in url: + expected_sha256 = url.split("/")[-2] + else: + expected_sha256 = "" + + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if ( + hashlib.sha256(open(download_target, "rb").read()).hexdigest() + == expected_sha256 + ): + return download_target + else: + warnings.warn( + f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" + ) + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm( + total=int(source.info().get("Content-Length")), + ncols=80, + unit="iB", + unit_scale=True, + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if ( + expected_sha256 + and hashlib.sha256(open(download_target, "rb").read()).hexdigest() + != expected_sha256 + ): + raise RuntimeError( + f"Model has been downloaded but the SHA256 checksum does not not match" + ) + + return download_target diff --git a/audioldm2/clap/open_clip/timm_model.py b/audioldm2/clap/open_clip/timm_model.py new file mode 100755 index 0000000000000000000000000000000000000000..b8486b9e62580bb65f0f50a0a7000890cb7ee42d --- /dev/null +++ b/audioldm2/clap/open_clip/timm_model.py @@ -0,0 +1,112 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" +from collections import OrderedDict + +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import ( + AttentionPool2d as AbsAttentionPool2d, + ) +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """timm model adapter + # FIXME this adapter is a work in progress, may change in ways that break weight compat + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool="avg", + proj="linear", + drop=0.0, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + + self.image_size = to_2tuple(image_size) + self.trunk = timm.create_model(model_name, pretrained=pretrained) + feat_size = self.trunk.default_cfg.get("pool_size", None) + feature_ndim = 1 if not feat_size else 2 + if pool in ("abs_attn", "rot_attn"): + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool="") + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == "abs_attn": + head_layers["pool"] = AbsAttentionPool2d( + prev_chs, feat_size=feat_size, out_features=embed_dim + ) + prev_chs = embed_dim + elif pool == "rot_attn": + head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, "projection layer needed if non-attention pooling is used." + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == "linear": + head_layers["drop"] = nn.Dropout(drop) + head_layers["proj"] = nn.Linear(prev_chs, embed_dim) + elif proj == "mlp": + head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" + ) + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/audioldm2/clap/open_clip/tokenizer.py b/audioldm2/clap/open_clip/tokenizer.py new file mode 100755 index 0000000000000000000000000000000000000000..ee4d28450ec5dd12a79daf38cf3088e9e73c2cd5 --- /dev/null +++ b/audioldm2/clap/open_clip/tokenizer.py @@ -0,0 +1,197 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" + ) + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + if not special_tokens: + special_tokens = ["", ""] + else: + special_tokens = ["", ""] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t: t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile( + special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + +_tokenizer = SimpleTokenizer() + + +def tokenize( + texts: Union[str, List[str]], context_length: int = 77 +) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + result[i, : len(tokens)] = torch.tensor(tokens) + + return result diff --git a/audioldm2/clap/open_clip/transform.py b/audioldm2/clap/open_clip/transform.py new file mode 100755 index 0000000000000000000000000000000000000000..77aaa722c4a5544ac50de6df35d3e922f63b111d --- /dev/null +++ b/audioldm2/clap/open_clip/transform.py @@ -0,0 +1,45 @@ +from torchvision.transforms import ( + Normalize, + Compose, + RandomResizedCrop, + InterpolationMode, + ToTensor, + Resize, + CenterCrop, +) + + +def _convert_to_rgb(image): + return image.convert("RGB") + + +def image_transform( + image_size: int, + is_train: bool, + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), +): + normalize = Normalize(mean=mean, std=std) + if is_train: + return Compose( + [ + RandomResizedCrop( + image_size, + scale=(0.9, 1.0), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) + else: + return Compose( + [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) diff --git a/audioldm2/clap/open_clip/utils.py b/audioldm2/clap/open_clip/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..77875569ff4aff81bf9545ce6ec58e0326d49d0c --- /dev/null +++ b/audioldm2/clap/open_clip/utils.py @@ -0,0 +1,356 @@ +import numpy as np +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d +import logging +import h5py +from tqdm import tqdm +import random +import json +import os +import pathlib + +# TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later. +dataset_split = { + "audiocaps": ["train", "valid", "test"], + "audioset": ["balanced_train", "unbalanced_train", "eval"], + "BBCSoundEffects": ["train", "test"], + "Clotho": ["train", "test", "valid"], + "free_to_use_sounds": ["train", "test"], + "paramount_motion": ["train", "test"], + "sonniss_game_effects": ["train", "test"], + "wesoundeffects": ["train", "test"], + "MACS": ["train", "test"], + "freesound": ["train", "test"], + "FSD50K": ["train", "test", "valid"], + "fsd50k_class_label": ["train", "test", "valid"], + "esc50": ["train", "test"], + "audiostock": ["train", "test"], + "freesound_no_overlap_noesc50": ["train", "test"], + "epidemic_sound_effects": ["train", "test"], + "VGGSound": ["train", "test"], + "urbansound8k_class_label": ["train", "test"], + "audioset_t5": ["balanced_train", "unbalanced_train", "eval"], + "epidemic_sound_effects_t5": ["train", "test"], + "WavText5K": ["train", "test"], + "esc50_no_overlap": ["train", "test"], + "usd8k_no_overlap": ["train", "test"], + "fsd50k_200_class_label": ["train", "test", "valid"], +} + + +def freeze_batch_norm_2d(module, module_match={}, name=""): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance( + module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) + ): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = ".".join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +def exist(dataset_name, dataset_type): + """ + Check if dataset exists + """ + if dataset_type in dataset_split[dataset_name]: + return True + else: + return False + + +def get_tar_path_from_dataset_name( + dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None +): + """ + Get tar path from dataset name and type + """ + output = [] + for n in dataset_names: + if full_dataset is not None and n in full_dataset: + current_dataset_types = dataset_split[n] + else: + current_dataset_types = dataset_types + for s in current_dataset_types: + tmp = [] + if islocal: + sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json" + if not os.path.exists(sizefilepath_): + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + else: + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + if not os.path.exists(sizefilepath_): + continue + sizes = json.load(open(sizefilepath_, "r")) + for k in sizes.keys(): + if islocal: + tmp.append(f"{dataset_path}/{n}/{s}/{k}") + else: + tmp.append( + f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -" + ) + if proportion != 1: + tmp = random.sample(tmp, int(proportion * len(tmp))) + output.append(tmp) + return sum(output, []) + + +def get_tar_path_from_txts(txt_path, islocal, proportion=1): + """ + Get tar path from txt path + """ + if isinstance(txt_path, (list, tuple)): + return sum( + [ + get_tar_path_from_txts( + txt_path[i], islocal=islocal, proportion=proportion + ) + for i in range(len(txt_path)) + ], + [], + ) + if isinstance(txt_path, str): + with open(txt_path) as f: + lines = f.readlines() + if islocal: + lines = [ + lines[i] + .split("\n")[0] + .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/") + for i in range(len(lines)) + ] + else: + lines = [ + lines[i].split("\n")[0].replace(".tar", ".tar -") + for i in range(len(lines)) + ] + if proportion != 1: + print("Sampling tars with proportion of {}".format(proportion)) + lines = random.sample(lines, int(proportion * len(lines))) + return lines + + +def get_mix_lambda(mixup_alpha, batch_size): + mixup_lambdas = [ + np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size) + ] + return np.array(mixup_lambdas).astype(np.float32) + + +def do_mixup(x, mixup_lambda): + """ + Args: + x: (batch_size , ...) + mixup_lambda: (batch_size,) + Returns: + out: (batch_size, ...) + """ + out = ( + x.transpose(0, -1) * mixup_lambda + + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda) + ).transpose(0, -1) + return out + + +def interpolate(x, ratio): + """Interpolate data in time domain. This is used to compensate the + resolution reduction in downsampling of a CNN. + + Args: + x: (batch_size, time_steps, classes_num) + ratio: int, ratio to interpolate + Returns: + upsampled: (batch_size, time_steps * ratio, classes_num) + """ + (batch_size, time_steps, classes_num) = x.shape + upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) + upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) + return upsampled + + +def pad_framewise_output(framewise_output, frames_num): + """Pad framewise_output to the same length as input frames. The pad value + is the same as the value of the last frame. + Args: + framewise_output: (batch_size, frames_num, classes_num) + frames_num: int, number of frames to pad + Outputs: + output: (batch_size, frames_num, classes_num) + """ + pad = framewise_output[:, -1:, :].repeat( + 1, frames_num - framewise_output.shape[1], 1 + ) + """tensor for padding""" + + output = torch.cat((framewise_output, pad), dim=1) + """(batch_size, frames_num, classes_num)""" + + +def process_ipc(index_path, classes_num, filename): + # load data + logging.info("Load Data...............") + ipc = [[] for _ in range(classes_num)] + with h5py.File(index_path, "r") as f: + for i in tqdm(range(len(f["target"]))): + t_class = np.where(f["target"][i])[0] + for t in t_class: + ipc[t].append(i) + print(ipc) + np.save(filename, ipc) + logging.info("Load Data Succeed...............") + + +def save_to_dict(s, o_={}): + sp = s.split(": ") + o_.update({sp[0]: float(sp[1])}) + return o_ + + +def get_data_from_log(txt_path): + """ + Output dictionary from out.txt log file + """ + with open(txt_path) as f: + lines = f.readlines() + val_data = {} + train_data = {} + train_losses = [] + train_losses_epoch = [] + for i in range(len(lines)): + if "| INFO |" in lines[i]: + if "Eval Epoch" in lines[i]: + if "val_loss" in lines[i]: + # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", "")) + line = lines[i].split("Eval Epoch: ")[-1] + num_epoch = int(line.split(" ")[0].split(" ")[0]) + d = { + line.split(" ")[0] + .split(" ")[1] + .replace(":", ""): float(line.split(" ")[0].split(" ")[-1]) + } + for i in range(1, len(line.split(" "))): + d = save_to_dict(line.split(" ")[i], d) + val_data[num_epoch] = d + elif "Train Epoch" in lines[i]: + num_epoch = int(lines[i].split("Train Epoch: ")[1][0]) + loss = float(lines[i].split("Loss: ")[-1].split(" (")[0]) + train_losses.append(loss) + train_losses_epoch.append(num_epoch) + for i in range(len(train_losses)): + train_data[i] = { + "num_epoch": train_losses_epoch[i], + "train_loss": train_losses[i], + } + return train_data, val_data + + +def save_p(obj, filename): + import pickle + + try: + from deepdiff import DeepDiff + except: + os.system("pip install deepdiff") + from deepdiff import DeepDiff + with open(filename, "wb") as file: + pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol + with open(filename, "rb") as file: + z = pickle.load(file) + assert ( + DeepDiff(obj, z, ignore_string_case=True) == {} + ), "there is something wrong with the saving process" + return + + +def load_p(filename): + import pickle + + with open(filename, "rb") as file: + z = pickle.load(file) + return z + + +def save_json(data, name="data.json"): + import json + + with open(name, "w") as fp: + json.dump(data, fp) + return + + +def load_json(name): + import json + + with open(name, "r") as fp: + data = json.load(fp) + return data + + +def load_class_label(path): + # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing + # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array + out = None + if path is not None: + if pathlib.Path(path).suffix in [".pkl", ".pickle"]: + out = load_p(path) + elif pathlib.Path(path).suffix in [".json", ".txt"]: + out = load_json(path) + elif pathlib.Path(path).suffix in [".npy", ".npz"]: + out = np.load(path) + elif pathlib.Path(path).suffix in [".csv"]: + import pandas as pd + + out = pd.read_csv(path) + return out + # if out is None: + # return None + # else: + # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False) + # val = Array('i', out.values(), lock=False) + # return (key, val) + + +from torch import optim + + +def get_optimizer(params, lr, betas, eps, momentum, optimizer_name): + if optimizer_name.lower() == "adamw": + optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps) + elif optimizer_name.lower() == "sgd": + optimizer = optim.SGD(params, lr=lr, momentum=momentum) + elif optimizer_name.lower() == "adam": + optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps) + else: + raise ValueError("optimizer name is not correct") + return optimizer diff --git a/audioldm2/clap/training/__init__.py b/audioldm2/clap/training/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm2/clap/training/audioset_textmap.npy b/audioldm2/clap/training/audioset_textmap.npy new file mode 100755 index 0000000000000000000000000000000000000000..3da4c92d3819aaec11e5f576464a9973a6df811b --- /dev/null +++ b/audioldm2/clap/training/audioset_textmap.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b +size 84448 diff --git a/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz b/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/audioldm2/clap/training/data.py b/audioldm2/clap/training/data.py new file mode 100755 index 0000000000000000000000000000000000000000..ae01406c63a9b1c678151f67dacd7ea192cb84f2 --- /dev/null +++ b/audioldm2/clap/training/data.py @@ -0,0 +1,865 @@ +import json +import logging +import os +import random +import h5py +from dataclasses import dataclass +import numpy as np +import pandas as pd +import torch +import torchvision.datasets as datasets +from PIL import Image +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler +from torch.utils.data.distributed import DistributedSampler +import soundfile as sf +import io +from pathlib import Path +# import wget + +from audioldm2.clap.open_clip.utils import get_tar_path_from_dataset_name +from audioldm2.clap.open_clip.utils import load_class_label + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +try: + import torchaudio +except ImportError: + torchaudio = None + +from audioldm2.clap.open_clip import tokenize + + +def tokenizer(text): + return tokenize(text).squeeze(0) + + +from transformers import RobertaTokenizer + +tokenize = RobertaTokenizer.from_pretrained("roberta-base") + + +def tokenizer(text): + result = tokenize( + text, + padding="max_length", + truncation=True, + max_length=77, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} + + +# initizlied the audioset map +_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy") +_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True) + + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def float32_to_int16(x): + x = np.clip(x, a_min=-1.0, a_max=1.0) + return (x * 32767.0).astype(np.int16) + + +# For Toy Dataset +class ToyDataset(Dataset): + def __init__(self, index_path, ipc, config, eval_mode=False): + """Toy Dataset for testing the audioset input with text labels + Parameters + ---------- + index_path: str + the link to the h5 file of each audio + idc: str + the link to the npy file, the number of samples in each class + config: dict + the audio cfg file + eval_model (bool): to indicate if the dataset is a testing dataset + """ + self.audio_cfg = config["audio_cfg"] + self.text_cfg = config["text_cfg"] + self.fp = h5py.File(index_path, "r") + self.ipc = np.load(ipc, allow_pickle=True) + self.total_size = len(self.fp["audio_name"]) + self.classes_num = self.audio_cfg["class_num"] + self.eval_mode = eval_mode + + if not eval_mode: + self.generate_queue() + else: + self.queue = [] + for i in range(self.total_size): + target = self.fp["target"][i] + if np.sum(target) > 0: + self.queue.append(i) + self.total_size = len(self.queue) + logging.info("total dataset size: %d" % (self.total_size)) + logging.info("class num: %d" % (self.classes_num)) + + def time_shifting(self, x): + frame_num = len(x) + shift_len = random.randint(0, frame_num - 1) + new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0) + return new_sample + + def generate_queue(self): + self.queue = [] + while len(self.queue) < self.total_size: + class_set = [*range(self.classes_num)] + random.shuffle(class_set) + self.queue += [ + self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set + ] + self.queue = self.queue[: self.total_size] + + logging.info("queue regenerated:%s" % (self.queue[-5:])) + + def crop_wav(self, x): + crop_size = self.audio_cfg["crop_size"] + crop_pos = random.randint(0, len(x) - crop_size - 1) + return x[crop_pos : crop_pos + crop_size] + + def prompt_text(self, target): + events = _AUDIOSET_MAP[np.where(target > 0)] + event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1] + text = tokenize(event_text)[0] + return text + + def __getitem__(self, index): + """Load waveform, text, and target of an audio clip + + Parameters + ---------- + index: int + the index number + Return + ------ + output: dict { + "hdf5_path": str, + "index_in_hdf5": int, + "audio_name": str, + "waveform": list (audio_length,), + "target": list (class_num, ), + "text": torch.tensor (context_length,) + } + the output dictionary + """ + s_index = self.queue[index] + + audio_name = self.fp["audio_name"][s_index].decode() + # Hardcode here CHANGE + hdf5_path = ( + self.fp["hdf5_path"][s_index] + .decode() + .replace( + "../workspace", + "/home/la/kechen/Research/ke_zsasp/workspace", + ) + ) + r_idx = self.fp["index_in_hdf5"][s_index] + target = self.fp["target"][s_index].astype(np.float32) + text = self.prompt_text(target) + with h5py.File(hdf5_path, "r") as f: + waveform = int16_to_float32(f["waveform"][r_idx])[ + : self.audio_cfg["clip_samples"] + ] + assert ( + len(waveform) == self.audio_cfg["clip_samples"] + ), "The sample length is not match" + # Time shift + # if (self.config.enable_time_shift) and (not self.eval_mode): + # waveform = self.time_shifting(waveform) + # # Label Enhance + # if (self.config.crop_size is not None) and (not self.eval_mode): + # waveform = self.crop_wav(waveform) + # # the label enhance rate is fixed 0.5 + # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5: + # kidx = np.where(target)[0] + # for k in kidx: + # for add_key in self.class_map[k][1]: + # target[add_key] = 1.0 + # if len(self.class_map[k][2]) > 0: + # add_key = random.choice(self.class_map[k][2]) + # target[add_key] = 1.0 + + # missing the text input + mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :] + mel_spec = ( + torch.cat( + [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0 + ) + .cpu() + .numpy() + ) + longer = random.choice([True, False]) + if longer == False: + mel_spec[1:, :, :] = 0.0 + data_dict = { + "hdf5_path": hdf5_path, + "index_in_hdf5": r_idx, + "audio_name": audio_name, + "waveform": waveform, + "class_label": target, + "text": text, + "longer": longer, + "mel_fusion": mel_spec, + } + return data_dict + + def __len__(self): + return self.total_size + + +class CsvDataset(Dataset): + def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"): + logging.debug(f"Loading csv data from {input_filename}.") + df = pd.read_csv(input_filename, sep=sep) + + self.images = df[img_key].tolist() + self.captions = df[caption_key].tolist() + self.transforms = transforms + logging.debug("Done loading data.") + + def __len__(self): + return len(self.captions) + + def __getitem__(self, idx): + images = self.transforms(Image.open(str(self.images[idx]))) + texts = tokenize([str(self.captions[idx])])[0] + return images, texts + + +@dataclass +class DataInfo: + dataloader: DataLoader + sampler: DistributedSampler + + +def preprocess_txt(text): + return tokenize([str(text)])[0] + + +# def get_dataset_size(shards, sizefilepath_=None, is_local=True): +# if isinstance(shards, list): +# size_list = [] +# for s in shards: +# size_list.append( +# get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0] +# ) +# else: +# if not is_local: +# for n in dataset_split.keys(): +# if n in shards.split("/"): +# break +# for s in dataset_split[n]: +# if s in shards.split("/"): +# break +# sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" +# shards_list = list(braceexpand.braceexpand(shards)) +# dir_path = os.path.dirname(shards) +# if sizefilepath_ is not None: +# sizes = json.load(open(sizefilepath_, "r")) +# total_size = sum( +# [ +# int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))]) +# for shard in shards_list +# ] +# ) +# else: +# sizes_filename = os.path.join(dir_path, "sizes.json") +# len_filename = os.path.join(dir_path, "__len__") +# if os.path.exists(sizes_filename): +# sizes = json.load(open(sizes_filename, "r")) +# total_size = sum( +# [int(sizes[os.path.basename(shard)]) for shard in shards_list] +# ) +# elif os.path.exists(len_filename): +# # FIXME this used to be eval(open(...)) but that seemed rather unsafe +# total_size = ast.literal_eval(open(len_filename, "r").read()) +# else: +# raise Exception( +# "Cannot find sizes file for dataset. Please specify the path to the file." +# ) +# # total_size = None # num samples undefined +# # some common dataset sizes (at time of authors last download) +# # cc3m-train: 2905954 +# # cc12m: 10968539 +# # LAION-400m: 407332084 +# num_shards = len(shards_list) +# if isinstance(shards, list): +# return sum(size_list), len(shards) +# else: +# return total_size, num_shards + + +def get_imagenet(args, preprocess_fns, split): + assert split in ["train", "val", "v2"] + is_train = split == "train" + preprocess_train, preprocess_val = preprocess_fns + + if split == "v2": + from imagenetv2_pytorch import ImageNetV2Dataset + + dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) + else: + if is_train: + data_path = args.imagenet_train + preprocess_fn = preprocess_train + else: + data_path = args.imagenet_val + preprocess_fn = preprocess_val + assert data_path + + dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) + + if is_train: + idxs = np.zeros(len(dataset.targets)) + target_array = np.array(dataset.targets) + k = 50 + for c in range(1000): + m = target_array == c + n = len(idxs[m]) + arr = np.zeros(n) + arr[:k] = 1 + np.random.shuffle(arr) + idxs[m] = arr + + idxs = idxs.astype("int") + sampler = SubsetRandomSampler(np.where(idxs)[0]) + else: + sampler = None + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.workers, + sampler=sampler, + ) + + return DataInfo(dataloader, sampler) + + +def count_samples(dataloader): + os.environ["WDS_EPOCH"] = "0" + n_elements, n_batches = 0, 0 + for images, texts in dataloader: + n_batches += 1 + n_elements += len(images) + assert len(images) == len(texts) + return n_elements, n_batches + + +def filter_no_caption(sample): + return "txt" in sample + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" + logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") + return True + + +_SHARD_SHUFFLE_SIZE = 2000 +_SHARD_SHUFFLE_INITIAL = 500 +_SAMPLE_SHUFFLE_SIZE = 5000 +_SAMPLE_SHUFFLE_INITIAL = 1000 + + +# def sample_prop(sizefile, inputs, proportion, is_local=True): +# """ +# Sample a proportion of the data. +# """ +# file_path_dict = { +# os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0] +# for i in range(len(inputs)) +# } +# sampled_filepath_dict = {} +# sampled_size_dict = {} +# if not is_local: +# if os.path.exists("sizes.json"): +# os.remove("sizes.json") +# wget.download(sizefile, "sizes.json") +# sizefile = "sizes.json" +# with open(sizefile, "r", encoding="UTF-8") as f: +# load_dict = json.load(f) +# L = int(len(file_path_dict) * proportion) +# subkeys = random.sample(file_path_dict.keys(), L) +# for k in subkeys: +# sampled_size_dict[k] = load_dict[k] +# sampled_filepath_dict[k] = file_path_dict[k] +# return ( +# sum(sampled_size_dict.values()), +# L, +# [os.path.join(v, k) for k, v in sampled_filepath_dict.items()], +# sampled_size_dict, +# ) + + +def get_mel(audio_data, audio_cfg): + # mel shape: (n_mels, T) + mel = torchaudio.transforms.MelSpectrogram( + sample_rate=audio_cfg["sample_rate"], + n_fft=audio_cfg["window_size"], + win_length=audio_cfg["window_size"], + hop_length=audio_cfg["hop_size"], + center=True, + pad_mode="reflect", + power=2.0, + norm=None, + onesided=True, + n_mels=64, + f_min=audio_cfg["fmin"], + f_max=audio_cfg["fmax"], + ).to(audio_data.device) + mel = mel(audio_data) + # we use log mel spectrogram as input + mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel) + return mel.T # (T, n_mels) + + +def get_audio_features( + audio_data, mel, max_len, data_truncating, data_filling, audio_cfg +): + """ + Calculate and add audio features to sample. + Sample: a dict containing all the data of current sample. + audio_data: a tensor of shape (T) containing audio data. + max_len: the maximum length of audio data. + data_truncating: the method of truncating data. + data_filling: the method of filling data. + audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg']. + """ + sample = {} + + # assert audio_data.size(-1) <= max_len, str(audio_data.size()) + + # split to three parts + chunk_frames = ( + max_len // audio_cfg["hop_size"] + 1 + ) # the +1 related to how the spectrogram is computed + mel = mel[:chunk_frames] + + audio_data = audio_data[..., :max_len] + sample["mel_fusion"] = mel + longer = torch.tensor([True]) + + sample["longer"] = longer + sample["waveform"] = audio_data + + return sample + + +def preprocess( + sample, + audio_ext, + text_ext, + max_len, + audio_cfg, + class_index_dict=None, + data_filling="pad", + data_truncating="rand_trunc", + text_augment_selection=None, +): + """ + Preprocess a single sample for wdsdataloader. + """ + audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) + audio_data = int16_to_float32(float32_to_int16(audio_data)) + audio_data = torch.tensor(audio_data).float() + + # TODO: (yusong) to be include in the future + # # if torchaudio not installed, use soundfile to load audio + # if torchaudio is None: + # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) + # audio_data = torch.tensor(audio_data).float() + # else: + # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py + # with tempfile.TemporaryDirectory() as dirname: + # os.makedirs(dirname, exist_ok=True) + # fname = os.path.join(dirname, f"file.flac") + # with open(fname, "wb") as stream: + # stream.write(sample[audio_ext]) + # audio_data, orig_sr = torchaudio.load(fname) + # audio_data = audio_data[0, :].float() + + sample = get_audio_features( + sample, audio_data, max_len, data_truncating, data_filling, audio_cfg + ) + del sample[audio_ext] + + try: + json_dict_raw = json.loads(sample[text_ext].decode("utf-8")) + except: + print("sample[__url__]:", sample["__url__"]) + + # For selecting augmented text from dataset + if text_augment_selection is None or text_augment_selection == "none": + texts = json_dict_raw["text"] + elif text_augment_selection == "all": + if "text_augment_all" in json_dict_raw.keys(): + texts = json_dict_raw["text_augment_all"] + else: + texts = json_dict_raw["text"] + elif text_augment_selection == "augment_only": + if "text_augment_all" in json_dict_raw.keys(): + if json_dict_raw["text_augment_t5"] is None: + texts = json_dict_raw["text"] + else: + texts = json_dict_raw["text_augment_t5"] + else: + texts = json_dict_raw["text"] + else: + raise NotImplementedError( + f"text_augment_selection {text_augment_selection} not implemented" + ) + sample["full_text"] = texts + + if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1: + texts = random.choice(texts) + sample["raw_text"] = texts + sample["text"] = tokenizer(texts) # text shape: [num_token] + if class_index_dict is not None: + # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing + # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array + # key, val = class_index_dict + # key = key[:].split('\n') + # _dict = {k: v for k, v in zip(key, val)} + sample["class_label"] = np.zeros(len(class_index_dict.keys())) + for x in json_dict_raw["tag"]: + sample["class_label"][class_index_dict[x]] = 1 + sample["class_label"] = torch.tensor(sample["class_label"]).float() + del sample[text_ext] + sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext + sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext + sample["audio_orig_sr"] = orig_sr + return sample + + +def collate_fn(batch): + """ + Collate function for wdsdataloader. + batch: a list of dict, each dict is a sample + """ + # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend. + batch_dict = {} + for k in batch[0].keys(): + if isinstance(batch[0][k], dict): # dealwith bert tokenizer output + batch_dict[k] = {} + for kk in batch[0][k].keys(): + tmp = [] + for i in range(len(batch)): + tmp.append(batch[i][k][kk]) + batch_dict[k][kk] = torch.vstack(tmp) + elif isinstance(batch[0][k], torch.Tensor): + batch_dict[k] = torch.stack([sample[k] for sample in batch]) + elif isinstance(batch[0][k], np.ndarray): + batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch])) + else: + batch_dict[k] = [sample[k] for sample in batch] + return batch_dict + + +# def get_wds_dataset( +# args, +# model_cfg, +# is_train, +# audio_ext="flac", +# text_ext="json", +# max_len=480000, +# proportion=1.0, +# sizefilepath_=None, +# is_local=None, +# ): +# """ +# Get a dataset for wdsdataloader. +# """ +# if is_local is None and (not args.remotedata is None): +# is_local = not args.remotedata + +# input_shards = args.train_data if is_train else args.val_data +# assert input_shards is not None + +# if not sizefilepath_ is None: +# sizefilepath = sizefilepath_ +# else: +# sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json") + +# if proportion != 1.0: +# num_samples, num_shards, input_shards, _ = sample_prop( +# sizefilepath, input_shards, proportion, is_local=is_local +# ) +# else: +# num_samples, num_shards = get_dataset_size( +# input_shards, sizefilepath_=sizefilepath_, is_local=is_local +# ) + +# if not num_samples: +# if is_train: +# num_samples = args.train_num_samples +# if not num_samples: +# raise RuntimeError( +# "Currently, number of dataset samples must be specified for training dataset. " +# "Please specify via `--train-num-samples` if no dataset length info present." +# ) +# else: +# num_samples = ( +# args.val_num_samples or 0 +# ) # eval will just exhaust the iterator if not specified + +# pipeline = [wds.SimpleShardList(input_shards)] +# # at this point we have an iterator over all the shards +# # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node +# if is_train or args.parallel_eval: +# pipeline.extend( +# [ +# wds.detshuffle( +# bufsize=_SHARD_SHUFFLE_SIZE, +# initial=_SHARD_SHUFFLE_INITIAL, +# seed=args.seed, +# ), +# wds.split_by_node, +# wds.split_by_worker, +# # at this point, we have an iterator over the shards assigned to each worker at each node +# wds.tarfile_to_samples(handler=log_and_continue), +# wds.shuffle( +# bufsize=_SAMPLE_SHUFFLE_SIZE, +# initial=_SAMPLE_SHUFFLE_INITIAL, +# rng=random.Random(args.seed), +# ), +# # wds.repeatedly, # FIXME determine if this is beneficial +# ] +# ) +# else: +# pipeline.extend( +# [ +# wds.split_by_worker, +# # at this point, we have an iterator over the shards assigned to each worker +# wds.tarfile_to_samples(handler=log_and_continue), +# ] +# ) +# pipeline.append( +# wds.map( +# partial( +# preprocess, +# audio_ext=audio_ext, +# text_ext=text_ext, +# max_len=max_len, +# audio_cfg=model_cfg["audio_cfg"], +# class_index_dict=copy.deepcopy(args.class_index_dict), +# data_filling=args.data_filling, +# data_truncating=args.data_truncating, +# text_augment_selection=args.text_augment_selection, +# ) +# ), +# ) + +# pipeline.append( +# wds.batched( +# args.batch_size, +# partial=not (is_train or args.parallel_eval), +# collation_fn=collate_fn, +# ) +# ) + +# dataset = wds.DataPipeline(*pipeline) +# if is_train or args.parallel_eval: +# # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples. +# # (yusong): See comments below. +# # roll over and repeat a few samples to get same number of full batches on each node +# global_batch_size = args.batch_size * args.world_size +# num_batches = math.ceil(num_samples / global_batch_size) +# num_workers = max(1, args.workers) +# num_worker_batches = math.ceil( +# num_batches / num_workers +# ) # per dataloader worker +# num_batches = num_worker_batches * num_workers +# num_samples = num_batches * global_batch_size +# dataset = dataset.with_epoch( +# num_worker_batches +# ) # each worker is iterating over this +# else: +# # last batches are partial, eval is done on single (master) node +# num_batches = math.ceil(num_samples / args.batch_size) + +# kwargs = {} +# if args.horovod: # multi-node training on summit +# kwargs["multiprocessing_context"] = "forkserver" + +# dataloader = wds.WebLoader( +# dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs +# ) + +# # FIXME not clear which approach is better, with_epoch before vs after dataloader? +# # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 +# # if is_train: +# # # roll over and repeat a few samples to get same number of full batches on each node +# # global_batch_size = args.batch_size * args.world_size +# # num_batches = math.ceil(num_samples / global_batch_size) +# # num_workers = max(1, args.workers) +# # num_batches = math.ceil(num_batches / num_workers) * num_workers +# # num_samples = num_batches * global_batch_size +# # dataloader = dataloader.with_epoch(num_batches) +# # else: +# # # last batches are partial, eval is done on single (master) node +# # num_batches = math.ceil(num_samples / args.batch_size) + +# # add meta-data to dataloader instance for convenience +# dataloader.num_batches = num_batches +# dataloader.num_samples = num_samples + +# return DataInfo(dataloader, None) + + +def wds_batch_list2dict( + batch, + keys=[ + "__url__", + "__key__", + "waveform", + "text", + "raw_text", + "audio_name", + "text_name", + "audio_orig_sr", + ], +): + """ + Return a dictionary of the batch, with keys as the names of the fields. + """ + assert len(keys) == len( + batch + ), "batch must have same number of keys as keys argument" + return {keys[i]: batch[i] for i in range(len(batch))} + + +def get_csv_dataset(args, preprocess_fn, is_train): + input_filename = args.train_data if is_train else args.val_data + assert input_filename + dataset = CsvDataset( + input_filename, + preprocess_fn, + img_key=args.csv_img_key, + caption_key=args.csv_caption_key, + sep=args.csv_separator, + ) + num_samples = len(dataset) + sampler = DistributedSampler(dataset) if args.distributed and is_train else None + shuffle = is_train and sampler is None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=shuffle, + num_workers=args.workers, + pin_memory=True, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_toy_dataset(args, model_cfg, is_train): + index_path = args.train_data if is_train else args.val_data + ipc_path = args.train_ipc if is_train else args.val_ipc + assert index_path and ipc_path + eval_mode = not is_train + dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode) + + num_samples = len(dataset) + sampler = ( + DistributedSampler(dataset, shuffle=False) + if args.distributed and is_train + else None + ) + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_dataset_fn(data_path, dataset_type): + if dataset_type == "webdataset": + return get_wds_dataset + elif dataset_type == "csv": + return get_csv_dataset + elif dataset_type == "auto": + ext = data_path.split(".")[-1] + if ext in ["csv", "tsv"]: + return get_csv_dataset + elif ext in ["tar"]: + return get_wds_dataset + else: + raise ValueError( + f"Tried to figure out dataset type, but failed for extention {ext}." + ) + elif dataset_type == "toy": + return get_toy_dataset + else: + raise ValueError(f"Unsupported dataset type: {dataset_type}") + + +def get_data(args, model_cfg): + data = {} + + args.class_index_dict = load_class_label(args.class_label_path) + + if args.datasetinfos is None: + args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] + if args.dataset_type == "webdataset": + args.train_data = get_tar_path_from_dataset_name( + args.datasetnames, + args.datasetinfos, + islocal=not args.remotedata, + proportion=args.dataset_proportion, + dataset_path=args.datasetpath, + full_dataset=args.full_train_dataset, + ) + + if args.full_train_dataset is None: + args.full_train_dataset = [] + if args.exclude_eval_dataset is None: + args.exclude_eval_dataset = [] + excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset + + val_dataset_names = ( + [n for n in args.datasetnames if n not in excluded_eval_datasets] + if excluded_eval_datasets + else args.datasetnames + ) + args.val_dataset_names = val_dataset_names + args.val_data = get_tar_path_from_dataset_name( + val_dataset_names, + ["valid", "test", "eval"], + islocal=not args.remotedata, + proportion=1, + dataset_path=args.datasetpath, + full_dataset=None, + ) + + if args.train_data: + data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( + args, model_cfg, is_train=True + ) + + if args.val_data: + data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( + args, model_cfg, is_train=False + ) + + return data diff --git a/audioldm2/clap/training/params.py b/audioldm2/clap/training/params.py new file mode 100755 index 0000000000000000000000000000000000000000..0cc1a0e2d982e900988cf5a4b24b2e59b093537b --- /dev/null +++ b/audioldm2/clap/training/params.py @@ -0,0 +1,563 @@ +import argparse + + +def get_default_params(model_name): + # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) + model_name = model_name.lower() + if "vit" in model_name: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} + else: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--train-data", + type=str, + default=None, + help="Path to h5 filewith training data", + ) + parser.add_argument( + "--val-data", + type=str, + default=None, + help="Path to h5 file with validation data", + ) + parser.add_argument( + "--freeze-text", + default=False, + action="store_true", + help="if you need to freeze the text encoder, make this True", + ) + parser.add_argument( + "--freeze-text-after", + type=int, + default=-1, + help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it", + ) + parser.add_argument( + "--train-ipc", + type=str, + default=None, + help="Path to npy file of the number of instance per class in training data", + ) + parser.add_argument( + "--val-ipc", + type=str, + default=None, + help="Path to npy file of the number of instance per class in validation data", + ) + parser.add_argument( + "--train-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Required for webdataset if not available in info file.", + ) + parser.add_argument( + "--val-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Useful for webdataset if not available in info file.", + ) + parser.add_argument( + "--dataset-type", + choices=["webdataset", "csv", "auto", "toy"], + default="auto", + help="Which type of dataset to process.", + ) + parser.add_argument( + "--csv-separator", + type=str, + default="\t", + help="For csv-like datasets, which separator to use.", + ) + parser.add_argument( + "--csv-img-key", + type=str, + default="filepath", + help="For csv-like datasets, the name of the key for the image paths.", + ) + parser.add_argument( + "--csv-caption-key", + type=str, + default="title", + help="For csv-like datasets, the name of the key for the captions.", + ) + parser.add_argument( + "--imagenet-val", + type=str, + default=None, + help="Path to imagenet val set for conducting zero shot evaluation.", + ) + parser.add_argument( + "--imagenet-v2", + type=str, + default=None, + help="Path to imagenet v2 for conducting zero shot evaluation.", + ) + parser.add_argument( + "--datasetnames", + nargs="+", + default=None, + help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects", + ) + parser.add_argument( + "--full-train-dataset", + nargs="+", + default=None, + help="Which dataset will be trained with all the subsets. (train+test)", + ) + parser.add_argument( + "--exclude-eval-dataset", + nargs="+", + default=None, + help="Which dataset will be excluded with evaluation", + ) + parser.add_argument( + "--datasetinfos", + nargs="+", + default=None, + help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval", + ) + parser.add_argument( + "--dataset-proportion", + type=float, + default=1.0, + help="How much proportion of dataset we want to train.", + ) + parser.add_argument( + "--remotedata", + default=False, + action="store_true", + help="if the dataset is remote, set this flag", + ) + parser.add_argument( + "--class-label-path", + type=str, + default=None, + help="The path of the class label pickle or csv.", + ) + parser.add_argument( + "--datasetpath", + type=str, + default="/mnt/audio_clip/webdataset_tar", + help="The path to the dataset", + ) + parser.add_argument( + "--logs", + type=str, + default="./logs/", + help="Where to store tensorboard logs. Use None to avoid storing logs.", + ) + parser.add_argument( + "--log-local", + action="store_true", + default=False, + help="log files on local master, otherwise global master only.", + ) + parser.add_argument( + "--name", + type=str, + default=None, + help="Optional identifier for the experiment when storing logs. Otherwise use current time.", + ) + parser.add_argument( + "--workers", type=int, default=1, help="Number of workers per GPU." + ) + parser.add_argument( + "--batch-size", type=int, default=64, help="Batch size per GPU." + ) + parser.add_argument( + "--epochs", type=int, default=32, help="Number of epochs to train for." + ) + parser.add_argument("--lr", type=float, default=None, help="Learning rate.") + parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") + parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") + parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") + parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.") + parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") + + parser.add_argument( + "--split-opt", + action="store_true", + default=False, + help="Use this flag to skip the learning rate decay.", + ) + parser.add_argument( + "--lr-pretrained", type=float, default=None, help="Learning rate for text." + ) + parser.add_argument( + "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text." + ) + parser.add_argument( + "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text." + ) + parser.add_argument( + "--eps-pretrained", type=float, default=None, help="Adam epsilon for text." + ) + parser.add_argument( + "--wd-pretrained", type=float, default=0.2, help="Weight decay for text." + ) + parser.add_argument( + "--momentum-pretrained", type=float, default=0.9, help="Momentum for text." + ) + parser.add_argument( + "--lr-new", type=float, default=None, help="Learning rate for audio." + ) + parser.add_argument( + "--beta1-new", type=float, default=None, help="Adam beta 1 for audio." + ) + parser.add_argument( + "--beta2-new", type=float, default=None, help="Adam beta 2 for audio." + ) + parser.add_argument( + "--eps-new", type=float, default=None, help="Adam epsilon for audio." + ) + parser.add_argument( + "--wd-new", type=float, default=0.2, help="Weight decay for audio." + ) + parser.add_argument( + "--momentum-new", type=float, default=0.9, help="Momentum for audio." + ) + parser.add_argument( + "--warmup", type=int, default=10000, help="Number of steps to warmup for." + ) + parser.add_argument( + "--use-bn-sync", + default=False, + action="store_true", + help="Whether to use batch norm sync.", + ) + parser.add_argument( + "--skip-scheduler", + action="store_true", + default=False, + help="Use this flag to skip the learning rate decay.", + ) + parser.add_argument( + "--save-frequency", type=int, default=1, help="How often to save checkpoints." + ) + parser.add_argument( + "--save-top-performance", + type=int, + default=0, + help="Save the top x performance weights if the value >0", + ) + parser.add_argument( + "--save-most-recent", + action="store_true", + default=False, + help="Always save the most recent model trained to epoch_latest.pt.", + ) + parser.add_argument( + "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." + ) + parser.add_argument( + "--val-frequency", + type=int, + default=1, + help="How often to run evaluation with val data.", + ) + parser.add_argument( + "--resume", + default=None, + type=str, + help="path to latest checkpoint (default: none)", + ) + parser.add_argument( + "--precision", + choices=["amp", "fp16", "fp32"], + default="amp", + help="Floating point precision.", + ) + parser.add_argument( + "--amodel", + type=str, + default="RN50", + help="Name of the audio backbone to use.", + ) + parser.add_argument( + "--tmodel", + type=str, + default="transformer", + help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]", + ) + parser.add_argument( + "--pretrained-audio", + default="", + type=str, + help="Use a pretrained audio model weights for the audio encoder of CLAP", + ) + parser.add_argument( + "--pretrained-text", + default="", + type=str, + help="Use a pretrained text model weights for the text encoder of CLAP", + ) + parser.add_argument( + "--pretrained", + default="", + type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--pretrained-image", + default=False, + action="store_true", + help="Load imagenet pretrained weights for image tower backbone if available.", + ) + parser.add_argument( + "--lock-image", + default=False, + action="store_true", + help="Lock full image tower by disabling gradients.", + ) + parser.add_argument( + "--lock-image-unlocked-groups", + type=int, + default=0, + help="Leave last n image tower layer groups unlocked.", + ) + parser.add_argument( + "--lock-image-freeze-bn-stats", + default=False, + action="store_true", + help="Freeze BatchNorm running stats in image tower for any locked layers.", + ) + parser.add_argument( + "--local-loss", + default=False, + action="store_true", + help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)", + ) + parser.add_argument( + "--gather-with-grad", + default=False, + action="store_true", + help="enable full distributed gradient for feature gather", + ) + parser.add_argument( + "--force-quick-gelu", + default=False, + action="store_true", + help="Force use of QuickGELU activation for non-OpenAI transformer models.", + ) + parser.add_argument( + "--torchscript", + default=False, + action="store_true", + help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", + ) + parser.add_argument( + "--trace", + default=False, + action="store_true", + help="torch.jit.trace the model for inference / eval only", + ) + # arguments for distributed training + parser.add_argument( + "--dist-url", + default="env://", + type=str, + help="url used to set up distributed training", + ) + parser.add_argument( + "--dist-backend", default="nccl", type=str, help="distributed backend" + ) + parser.add_argument( + "--report-to", + default="", + type=str, + help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']", + ) + parser.add_argument( + "--wandb-notes", default="", type=str, help="Notes if logging with wandb" + ) + parser.add_argument( + "--C", type=float, default=3.16, help="inverse regularizer for logistic reg." + ) + parser.add_argument( + "--debug", + default=False, + action="store_true", + help="If true, more information is logged.", + ) + parser.add_argument( + "--copy-codebase", + default=False, + action="store_true", + help="If true, we copy the entire base on the log diretory, and execute from there.", + ) + parser.add_argument( + "--horovod", + default=False, + action="store_true", + help="Use horovod for distributed training.", + ) + parser.add_argument( + "--ddp-static-graph", + default=False, + action="store_true", + help="Enable static graph optimization for DDP in PyTorch >= 1.11.", + ) + parser.add_argument( + "--no-set-device-rank", + default=False, + action="store_true", + help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", + ) + parser.add_argument("--seed", type=int, default=4242, help="Default random seed.") + + parser.add_argument( + "--top-k-checkpoint-select-dataset", + type=str, + default="all", + help="The dataset of selecting top-k checkpoint.", + ) + + # @R10, @R@5, @R1, mAP@10 + parser.add_argument( + "--top-k-checkpoint-select-metric", + type=str, + default="_R@10", + help="The metric for selecting top-k checkpoint.", + ) + parser.add_argument( + "--openai-model-cache-dir", + type=str, + default="~/.cache/clip", + help="Directory to download OpenAI models.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="adamw", + help="can be AdamW or SGD", + ) + parser.add_argument( + "--parallel-eval", + default=False, + action="store_true", + help="Eval in parallel (multi-GPU, multi-node).", + ) + + parser.add_argument( + "--no-eval", + default=False, + action="store_true", + help="Training without evaluation.", + ) + + parser.add_argument( + "--lp-mlp", + default=False, + action="store_true", + help="Linear Probe using MLP layer or not.", + ) + + parser.add_argument( + "--lp-freeze", + default=False, + action="store_true", + help="Linear Probe using Freeze CLAP or not", + ) + + parser.add_argument( + "--lp-act", + default="None", + type=str, + help="Options are ['relu','elu','prelu','softmax','sigmoid']", + ) + + parser.add_argument( + "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe." + ) + + parser.add_argument( + "--lp-metrics", + type=str, + default="map,mauc,acc", + help="Metrics of Linear Probe.", + ) + + parser.add_argument( + "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe" + ) + parser.add_argument( + "--kappa", + type=float, + default=0, + help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss", + ) + + parser.add_argument( + "--data-filling", + type=str, + default="pad", + help="type of data filling when the audio length is shorter than the max length." + "Can be one of the following: repeat, repeatpad, pad", + ) + parser.add_argument( + "--data-truncating", + type=str, + default="rand_trunc", + help="type of data truncation when the audio length is longer than the max length." + "Can be one of the following: rand_trunc, fusion", + ) + + parser.add_argument( + "--clap-mlploss", + default=False, + action="store_true", + help="Using MLP loss for CLAP model or not", + ) + + parser.add_argument( + "--wandb-id", + type=str, + default=None, + help="the id of wandb experiment to restore.", + ) + + parser.add_argument( + "--sleep", type=float, default=0, help="sleep n seconds before start training" + ) + + # variable length processing + parser.add_argument( + "--enable-fusion", + default=False, + action="store_true", + help="Enable feature funsion for variable-length data", + ) + + parser.add_argument( + "--fusion-type", + type=str, + default="None", + help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']", + ) + + parser.add_argument( + "--mixup", + default=False, + action="store_true", + help="Enable mixup in finetuning training.", + ) + parser.add_argument( + "--text-augment-selection", + type=str, + default=None, + help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']", + ) + + args = parser.parse_args() + + # If some params are not passed, we use the default values based on model name. + default_params = get_default_params(args.amodel) + for name, val in default_params.items(): + if getattr(args, name) is None: + setattr(args, name, val) + + return args diff --git a/audioldm2/hifigan/LICENSE b/audioldm2/hifigan/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..5afae394d6b37da0e12ba6b290d2512687f421ac --- /dev/null +++ b/audioldm2/hifigan/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Jungil Kong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/audioldm2/hifigan/__init__.py b/audioldm2/hifigan/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..34e055557bf2ecb457376663b67390543c71fb1f --- /dev/null +++ b/audioldm2/hifigan/__init__.py @@ -0,0 +1,8 @@ +from .models_v2 import Generator +from .models import Generator as Generator_old + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self diff --git a/audioldm2/hifigan/models.py b/audioldm2/hifigan/models.py new file mode 100755 index 0000000000000000000000000000000000000000..c4382cc39de0463f9b7c0f33f037dbc233e7cb36 --- /dev/null +++ b/audioldm2/hifigan/models.py @@ -0,0 +1,174 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock, self).__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + # print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/audioldm2/hifigan/models_v2.py b/audioldm2/hifigan/models_v2.py new file mode 100755 index 0000000000000000000000000000000000000000..27a2df6b54bdd3a5b259645442624800ac0e8afe --- /dev/null +++ b/audioldm2/hifigan/models_v2.py @@ -0,0 +1,395 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock1 if h.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + u * 2, + u, + padding=u // 2 + u % 2, + output_padding=u % 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + # import ipdb; ipdb.set_trace() + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + # print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +################################################################################################## + +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# from torch.nn import Conv1d, ConvTranspose1d +# from torch.nn.utils import weight_norm, remove_weight_norm + +# LRELU_SLOPE = 0.1 + + +# def init_weights(m, mean=0.0, std=0.01): +# classname = m.__class__.__name__ +# if classname.find("Conv") != -1: +# m.weight.data.normal_(mean, std) + + +# def get_padding(kernel_size, dilation=1): +# return int((kernel_size * dilation - dilation) / 2) + + +# class ResBlock(torch.nn.Module): +# def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): +# super(ResBlock, self).__init__() +# self.h = h +# self.convs1 = nn.ModuleList( +# [ +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=dilation[0], +# padding=get_padding(kernel_size, dilation[0]), +# ) +# ), +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=dilation[1], +# padding=get_padding(kernel_size, dilation[1]), +# ) +# ), +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=dilation[2], +# padding=get_padding(kernel_size, dilation[2]), +# ) +# ), +# ] +# ) +# self.convs1.apply(init_weights) + +# self.convs2 = nn.ModuleList( +# [ +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=1, +# padding=get_padding(kernel_size, 1), +# ) +# ), +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=1, +# padding=get_padding(kernel_size, 1), +# ) +# ), +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=1, +# padding=get_padding(kernel_size, 1), +# ) +# ), +# ] +# ) +# self.convs2.apply(init_weights) + +# def forward(self, x): +# for c1, c2 in zip(self.convs1, self.convs2): +# xt = F.leaky_relu(x, LRELU_SLOPE) +# xt = c1(xt) +# xt = F.leaky_relu(xt, LRELU_SLOPE) +# xt = c2(xt) +# x = xt + x +# return x + +# def remove_weight_norm(self): +# for l in self.convs1: +# remove_weight_norm(l) +# for l in self.convs2: +# remove_weight_norm(l) + +# class Generator(torch.nn.Module): +# def __init__(self, h): +# super(Generator, self).__init__() +# self.h = h +# self.num_kernels = len(h.resblock_kernel_sizes) +# self.num_upsamples = len(h.upsample_rates) +# self.conv_pre = weight_norm( +# Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) +# ) +# resblock = ResBlock + +# self.ups = nn.ModuleList() +# for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): +# self.ups.append( +# weight_norm( +# ConvTranspose1d( +# h.upsample_initial_channel // (2**i), +# h.upsample_initial_channel // (2 ** (i + 1)), +# k, +# u, +# padding=(k - u) // 2, +# ) +# ) +# ) + +# self.resblocks = nn.ModuleList() +# for i in range(len(self.ups)): +# ch = h.upsample_initial_channel // (2 ** (i + 1)) +# for j, (k, d) in enumerate( +# zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) +# ): +# self.resblocks.append(resblock(h, ch, k, d)) + +# self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) +# self.ups.apply(init_weights) +# self.conv_post.apply(init_weights) + +# def forward(self, x): +# x = self.conv_pre(x) +# for i in range(self.num_upsamples): +# x = F.leaky_relu(x, LRELU_SLOPE) +# x = self.ups[i](x) +# xs = None +# for j in range(self.num_kernels): +# if xs is None: +# xs = self.resblocks[i * self.num_kernels + j](x) +# else: +# xs += self.resblocks[i * self.num_kernels + j](x) +# x = xs / self.num_kernels +# x = F.leaky_relu(x) +# x = self.conv_post(x) +# x = torch.tanh(x) + +# return x + +# def remove_weight_norm(self): +# print("Removing weight norm...") +# for l in self.ups: +# remove_weight_norm(l) +# for l in self.resblocks: +# l.remove_weight_norm() +# remove_weight_norm(self.conv_pre) +# remove_weight_norm(self.conv_post) diff --git a/audioldm2/latent_diffusion/__init__.py b/audioldm2/latent_diffusion/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm2/latent_diffusion/models/__init__.py b/audioldm2/latent_diffusion/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm2/latent_diffusion/models/ddim.py b/audioldm2/latent_diffusion/models/ddim.py new file mode 100755 index 0000000000000000000000000000000000000000..0c07207af7959847552805f00831122304b4330e --- /dev/null +++ b/audioldm2/latent_diffusion/models/ddim.py @@ -0,0 +1,487 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from audioldm2.latent_diffusion.modules.diffusionmodules.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, + extract_into_tensor, +) + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.device = device + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != self.device: + attr = attr.to(self.device) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs, + ): + # if conditioning is not None: + # if isinstance(conditioning, dict): + # ctmp = conditioning[list(conditioning.keys())[0]] + # while isinstance(ctmp, list): ctmp = ctmp[0] + # cbs = ctmp.shape[0] + # if cbs != batch_size: + # print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + # elif isinstance(conditioning, list): + # for ctmp in conditioning: + # if ctmp.shape[0] != batch_size: + # print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + # else: + # if conditioning.shape[0] != batch_size: + # print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ucg_schedule=None, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? + img = img_orig * mask + (1.0 - mask) * img + + if ucg_schedule is not None: + assert len(ucg_schedule) == len(time_range) + unconditional_guidance_scale = ucg_schedule[i] + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + img, pred_x0 = outs + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: + model_output = self.model.apply_model(x, t, c) + else: + x_in = x + t_in = t + + assert isinstance(c, dict) + assert isinstance(unconditional_conditioning, dict) + + model_uncond = self.model.apply_model( + x_in, t_in, unconditional_conditioning + ) + model_t = self.model.apply_model(x_in, t_in, c) + + model_output = model_uncond + unconditional_guidance_scale * ( + model_t - model_uncond + ) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", "not implemented" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode( + self, + x0, + c, + t_enc, + use_original_steps=False, + return_intermediates=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + callback=None, + ): + num_reference_steps = ( + self.ddpm_num_timesteps + if use_original_steps + else self.ddim_timesteps.shape[0] + ) + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc="Encoding Image"): + t = torch.full( + (x0.shape[0],), i, device=self.model.device, dtype=torch.long + ) + if unconditional_guidance_scale == 1.0: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model( + torch.cat((x_next, x_next)), + torch.cat((t, t)), + torch.cat((unconditional_conditioning, c)), + ), + 2, + ) + noise_pred = e_t_uncond + unconditional_guidance_scale * ( + noise_pred - e_t_uncond + ) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = ( + alphas_next[i].sqrt() + * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) + * noise_pred + ) + x_next = xt_weighted + weighted_noise_pred + if ( + return_intermediates + and i % (num_steps // return_intermediates) == 0 + and i < num_steps - 1 + ): + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: + callback(i) + + out = {"x_encoded": x_next, "intermediate_steps": inter_steps} + if return_intermediates: + out.update({"intermediates": intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) + + @torch.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + callback=None, + ): + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full( + (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long + ) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + if callback: + callback(i) + return x_dec diff --git a/audioldm2/latent_diffusion/models/ddpm.py b/audioldm2/latent_diffusion/models/ddpm.py new file mode 100755 index 0000000000000000000000000000000000000000..df3a6c032ba2ec61250212a31d68184e763dcf0e --- /dev/null +++ b/audioldm2/latent_diffusion/models/ddpm.py @@ -0,0 +1,1840 @@ +from multiprocessing.sharedctypes import Value +import os + +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from audioldm2.latent_diffusion.modules.encoders.modules import * + +from audioldm2.latent_diffusion.util import ( + exists, + default, + count_params, + instantiate_from_config, +) +from audioldm2.latent_diffusion.modules.ema import LitEma +from audioldm2.latent_diffusion.modules.distributions.distributions import ( + DiagonalGaussianDistribution, +) + +# from latent_encoder.autoencoder import ( +# VQModelInterface, +# IdentityFirstStage, +# AutoencoderKL, +# ) + +from audioldm2.latent_diffusion.modules.diffusionmodules.util import ( + make_beta_schedule, + extract_into_tensor, + noise_like, +) + +from audioldm2.latent_diffusion.models.ddim import DDIMSampler +from audioldm2.latent_diffusion.models.plms import PLMSSampler +import soundfile as sf +import os + +__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} + +CACHE_DIR = os.getenv( + "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2") +) + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(nn.Module): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + unet_config, + sampling_rate=None, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + latent_t_size=256, + latent_f_size=16, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0.0, + evaluator=None, + device=None, + ): + super().__init__() + assert parameterization in [ + "eps", + "x0", + "v", + ], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = parameterization + self.state = None + self.device = device + # print( + # f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" + # ) + assert sampling_rate is not None + self.validation_folder_name = "temp_name" + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.sampling_rate = sampling_rate + + self.clap = CLAPAudioEmbeddingClassifierFreev2( + pretrained_path="", + sampling_rate=self.sampling_rate, + embed_mode="audio", + amodel="HTSAT-base", + ) + + self.initialize_param_check_toolkit() + + self.latent_t_size = latent_t_size + self.latent_f_size = latent_f_size + + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt( + ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet + ) + + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + else: + self.logvar = nn.Parameter(self.logvar, requires_grad=False) + + self.logger_save_dir = None + self.logger_exp_name = None + self.logger_exp_group_name = None + self.logger_version = None + + self.label_indices_total = None + # To avoid the system cannot find metric value for checkpoint + self.metrics_buffer = { + "val/kullback_leibler_divergence_sigmoid": 15.0, + "val/kullback_leibler_divergence_softmax": 10.0, + "val/psnr": 0.0, + "val/ssim": 0.0, + "val/inception_score_mean": 1.0, + "val/inception_score_std": 0.0, + "val/kernel_inception_distance_mean": 0.0, + "val/kernel_inception_distance_std": 0.0, + "val/frechet_inception_distance": 133.0, + "val/frechet_audio_distance": 32.0, + } + self.initial_learning_rate = None + self.test_data_subset_path = None + + def get_log_dir(self): + return os.path.join( + self.logger_save_dir, self.logger_exp_group_name, self.logger_exp_name + ) + + def set_log_dir(self, save_dir, exp_group_name, exp_name): + self.logger_save_dir = save_dir + self.logger_exp_group_name = exp_group_name + self.logger_exp_name = exp_name + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = ( + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + ) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like( + self.betas**2 + / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + ) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + # if context is not None: + # print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + # if context is not None: + # print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = ( + self.load_state_dict(sd, strict=False) + if not only_model + else self.model.load_state_dict(sd, strict=False) + ) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised + ) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = ( + (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous() + ) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm( + reversed(range(0, self.num_timesteps)), + desc="Sampling t", + total=self.num_timesteps, + ): + img = self.p_sample( + img, + torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised, + ) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + shape = (batch_size, channels, self.latent_t_size, self.latent_f_size) + self.channels + return self.p_sample_loop(shape, return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == "l1": + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == "l2": + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction="none") + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) + * x_t + ) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint( + 0, self.num_timesteps, (x.shape[0],), device=self.device + ).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + # fbank, log_magnitudes_stft, label_indices, fname, waveform, clip_label, text = batch + # fbank, stft, label_indices, fname, waveform, text = batch + fname, text, waveform, stft, fbank = ( + batch["fname"], + batch["text"], + batch["waveform"], + batch["stft"], + batch["log_mel_spec"], + ) + # for i in range(fbank.size(0)): + # fb = fbank[i].numpy() + # seg_lb = seg_label[i].numpy() + # logits = np.mean(seg_lb, axis=0) + # index = np.argsort(logits)[::-1][:5] + # plt.imshow(seg_lb[:,index], aspect="auto") + # plt.title(index) + # plt.savefig("%s_label.png" % i) + # plt.close() + # plt.imshow(fb, aspect="auto") + # plt.savefig("%s_fb.png" % i) + # plt.close() + ret = {} + + ret["fbank"] = ( + fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float() + ) + ret["stft"] = stft.to(memory_format=torch.contiguous_format).float() + # ret["clip_label"] = clip_label.to(memory + # _format=torch.contiguous_format).float() + ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float() + ret["text"] = list(text) + ret["fname"] = fname + + for key in batch.keys(): + if key not in ret.keys(): + ret[key] = batch[key] + + return ret[k] + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample( + batch_size=N, return_intermediates=True + ) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + def initialize_param_check_toolkit(self): + self.tracked_steps = 0 + self.param_dict = {} + + def statistic_require_grad_tensor_number(self, module, name=None): + requires_grad_num = 0 + total_num = 0 + require_grad_tensor = None + for p in module.parameters(): + if p.requires_grad: + requires_grad_num += 1 + if require_grad_tensor is None: + require_grad_tensor = p + total_num += 1 + print( + "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)" + % (name, requires_grad_num, total_num, requires_grad_num / total_num) + ) + return require_grad_tensor + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__( + self, + first_stage_config, + cond_stage_config=None, + num_timesteps_cond=None, + cond_stage_key="image", + optimize_ddpm_parameter=True, + unconditional_prob_cfg=0.1, + warmup_steps=10000, + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + batchsize=None, + evaluation_params={}, + scale_by_std=False, + base_learning_rate=None, + *args, + **kwargs, + ): + self.learning_rate = base_learning_rate + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + self.warmup_steps = warmup_steps + + if optimize_ddpm_parameter: + if unconditional_prob_cfg == 0.0: + "You choose to optimize DDPM. The classifier free guidance scale should be 0.1" + unconditional_prob_cfg = 0.1 + else: + if unconditional_prob_cfg == 0.1: + "You choose not to optimize DDPM. The classifier free guidance scale should be 0.0" + unconditional_prob_cfg = 0.0 + + self.evaluation_params = evaluation_params + assert self.num_timesteps_cond <= kwargs["timesteps"] + + # for backwards compatibility after implementation of DiffusionWrapper + # if conditioning_key is None: + # conditioning_key = "concat" if concat_mode else "crossattn" + # if cond_stage_config == "__is_unconditional__": + # conditioning_key = None + + conditioning_key = list(cond_stage_config.keys()) + + self.conditioning_key = conditioning_key + + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + + self.optimize_ddpm_parameter = optimize_ddpm_parameter + # if(not optimize_ddpm_parameter): + # print("Warning: Close the optimization of the latent diffusion model") + # for p in self.model.parameters(): + # p.requires_grad=False + + self.concat_mode = concat_mode + self.cond_stage_key = cond_stage_key + self.cond_stage_key_orig = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer("scale_factor", torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.unconditional_prob_cfg = unconditional_prob_cfg + self.cond_stage_models = nn.ModuleList([]) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + self.conditional_dry_run_finished = False + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + + for each in self.cond_stage_models: + params = params + list( + each.parameters() + ) # Add the parameter from the conditional stage + + if self.learn_logvar: + print("Diffusion model optimizing logvar") + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + # if self.use_scheduler: + # assert "target" in self.scheduler_config + # scheduler = instantiate_from_config(self.scheduler_config) + + # print("Setting up LambdaLR scheduler...") + # scheduler = [ + # { + # "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), + # "interval": "step", + # "frequency": 1, + # } + # ] + # return [opt], scheduler + return opt + + def make_cond_schedule( + self, + ): + self.cond_ids = torch.full( + size=(self.num_timesteps,), + fill_value=self.num_timesteps - 1, + dtype=torch.long, + ) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) + ).long() + self.cond_ids[: self.num_timesteps_cond] = ids + + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if ( + self.scale_factor == 1 + and self.scale_by_std + and self.current_epoch == 0 + and self.global_step == 0 + and batch_idx == 0 + and not self.restarted_from_ckpt + ): + # assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer("scale_factor", 1.0 / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + super().register_schedule( + given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s + ) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def make_decision(self, probability): + if float(torch.rand(1)) < probability: + return True + else: + return False + + def instantiate_cond_stage(self, config): + self.cond_stage_model_metadata = {} + for i, cond_model_key in enumerate(config.keys()): + model = instantiate_from_config(config[cond_model_key]) + self.cond_stage_models.append(model) + self.cond_stage_model_metadata[cond_model_key] = { + "model_idx": i, + "cond_stage_key": config[cond_model_key]["cond_stage_key"], + "conditioning_key": config[cond_model_key]["conditioning_key"], + } + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + def get_learned_conditioning(self, c, key, unconditional_cfg): + assert key in self.cond_stage_model_metadata.keys() + + # Classifier-free guidance + if not unconditional_cfg: + c = self.cond_stage_models[ + self.cond_stage_model_metadata[key]["model_idx"] + ](c) + else: + # when the cond_stage_key is "all", pick one random element out + if isinstance(c, dict): + c = c[list(c.keys())[0]] + + if isinstance(c, torch.Tensor): + batchsize = c.size(0) + elif isinstance(c, list): + batchsize = len(c) + else: + raise NotImplementedError() + + c = self.cond_stage_models[ + self.cond_stage_model_metadata[key]["model_idx"] + ].get_unconditional_condition(batchsize) + + return c + + def get_input( + self, + batch, + k, + return_first_stage_encode=True, + return_decoding_output=False, + return_encoder_input=False, + return_encoder_output=False, + unconditional_prob_cfg=0.1, + ): + x = super().get_input(batch, k) + + x = x.to(self.device) + + if return_first_stage_encode: + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + else: + z = None + cond_dict = {} + if len(self.cond_stage_model_metadata.keys()) > 0: + unconditional_cfg = False + if self.conditional_dry_run_finished and self.make_decision( + unconditional_prob_cfg + ): + unconditional_cfg = True + for cond_model_key in self.cond_stage_model_metadata.keys(): + cond_stage_key = self.cond_stage_model_metadata[cond_model_key][ + "cond_stage_key" + ] + + if cond_model_key in cond_dict.keys(): + continue + + if not self.training: + if isinstance( + self.cond_stage_models[ + self.cond_stage_model_metadata[cond_model_key]["model_idx"] + ], + CLAPAudioEmbeddingClassifierFreev2, + ): + print( + "Warning: CLAP model normally should use text for evaluation" + ) + + # The original data for conditioning + # If cond_model_key is "all", that means the conditional model need all the information from a batch + + if cond_stage_key != "all": + xc = super().get_input(batch, cond_stage_key) + if type(xc) == torch.Tensor: + xc = xc.to(self.device) + else: + xc = batch + + # if cond_stage_key is "all", xc will be a dictionary containing all keys + # Otherwise xc will be an entry of the dictionary + c = self.get_learned_conditioning( + xc, key=cond_model_key, unconditional_cfg=unconditional_cfg + ) + + # cond_dict will be used to condition the diffusion model + # If one conditional model return multiple conditioning signal + if isinstance(c, dict): + for k in c.keys(): + cond_dict[k] = c[k] + else: + cond_dict[cond_model_key] = c + + # If the key is accidently added to the dictionary and not in the condition list, remove the condition + # for k in list(cond_dict.keys()): + # if(k not in self.cond_stage_model_metadata.keys()): + # del cond_dict[k] + + out = [z, cond_dict] + + if return_decoding_output: + xrec = self.decode_first_stage(z) + out += [xrec] + + if return_encoder_input: + out += [x] + + if return_encoder_output: + out += [encoder_posterior] + + if not self.conditional_dry_run_finished: + self.conditional_dry_run_finished = True + + # Output is a dictionary, where the value could only be tensor or tuple + return out + + def decode_first_stage(self, z): + with torch.no_grad(): + z = 1.0 / self.scale_factor * z + decoding = self.first_stage_model.decode(z) + return decoding + + def mel_spectrogram_to_waveform( + self, mel, savepath=".", bs=None, name="outwav", save=True + ): + # Mel: [bs, 1, t-steps, fbins] + if len(mel.size()) == 4: + mel = mel.squeeze(1) + mel = mel.permute(0, 2, 1) + waveform = self.first_stage_model.vocoder(mel) + waveform = waveform.cpu().detach().numpy() + if save: + self.save_waveform(waveform, savepath, name) + return waveform + + def encode_first_stage(self, x): + with torch.no_grad(): + return self.first_stage_model.encode(x) + + def extract_possible_loss_in_cond_dict(self, cond_dict): + # This function enable the conditional module to return loss function that can optimize them + + assert isinstance(cond_dict, dict) + losses = {} + + for cond_key in cond_dict.keys(): + if "loss" in cond_key and "noncond" in cond_key: + assert cond_key not in losses.keys() + losses[cond_key] = cond_dict[cond_key] + + return losses + + def filter_useful_cond_dict(self, cond_dict): + new_cond_dict = {} + for key in cond_dict.keys(): + if key in self.cond_stage_model_metadata.keys(): + new_cond_dict[key] = cond_dict[key] + + # All the conditional key in the metadata should be used + for key in self.cond_stage_model_metadata.keys(): + assert key in new_cond_dict.keys(), "%s, %s" % ( + key, + str(new_cond_dict.keys()), + ) + + return new_cond_dict + + def shared_step(self, batch, **kwargs): + if self.training: + # Classifier-free guidance + unconditional_prob_cfg = self.unconditional_prob_cfg + else: + unconditional_prob_cfg = 0.0 # TODO possible bug here + + x, c = self.get_input( + batch, self.first_stage_key, unconditional_prob_cfg=unconditional_prob_cfg + ) + + if self.optimize_ddpm_parameter: + loss, loss_dict = self(x, self.filter_useful_cond_dict(c)) + else: + loss_dict = {} + loss = None + + additional_loss_for_cond_modules = self.extract_possible_loss_in_cond_dict(c) + assert isinstance(additional_loss_for_cond_modules, dict) + + loss_dict.update(additional_loss_for_cond_modules) + + if len(additional_loss_for_cond_modules.keys()) > 0: + for k in additional_loss_for_cond_modules.keys(): + if loss is None: + loss = additional_loss_for_cond_modules[k] + else: + loss = loss + additional_loss_for_cond_modules[k] + + # for k,v in additional_loss_for_cond_modules.items(): + # self.log( + # "cond_stage/"+k, + # float(v), + # prog_bar=True, + # logger=True, + # on_step=True, + # on_epoch=True, + # ) + if self.training: + assert loss is not None + + return loss, loss_dict + + def forward(self, x, c, *args, **kwargs): + t = torch.randint( + 0, self.num_timesteps, (x.shape[0],), device=self.device + ).long() + + # assert c is not None + # c = self.get_learned_conditioning(c) + + loss, loss_dict = self.p_losses(x, c, t, *args, **kwargs) + return loss, loss_dict + + def reorder_cond_dict(self, cond_dict): + # To make sure the order is correct + new_cond_dict = {} + for key in self.conditioning_key: + new_cond_dict[key] = cond_dict[key] + return new_cond_dict + + def apply_model(self, x_noisy, t, cond, return_ids=False): + cond = self.reorder_cond_dict(cond) + + x_recon = self.model(x_noisy, t, cond_dict=cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = "train" if self.training else "val" + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError() + # print(model_output.size(), target.size()) + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f"{prefix}/loss_gamma": loss.mean()}) + loss_dict.update({"logvar": self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f"{prefix}/loss_vlb": loss_vlb}) + loss += self.original_elbo_weight * loss_vlb + loss_dict.update({f"{prefix}/loss": loss}) + + return loss, loss_dict + + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score( + self, model_out, x, t, c, **corrector_kwargs + ) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = ( + (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous() + ) + + # if return_codebook_ids: + # return model_mean + nonzero_mask * ( + # 0.5 * model_log_variance + # ).exp() * noise, logits.argmax(dim=1) + if return_x0: + return ( + model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, + x0, + ) + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising( + self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + ): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm( + reversed(range(0, timesteps)), + desc="Progressive Generation", + total=timesteps, + ) + if verbose + else reversed(range(0, timesteps)) + ) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + ): + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + ) + + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample( + self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs, + ): + if shape is None: + shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size) + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + return self.p_sample_loop( + cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + **kwargs, + ) + + def save_waveform(self, waveform, savepath, name="outwav"): + for i in range(waveform.shape[0]): + if type(name) is str: + path = os.path.join( + savepath, "%s_%s_%s.wav" % (self.global_step, i, name) + ) + elif type(name) is list: + path = os.path.join( + savepath, + "%s.wav" + % ( + os.path.basename(name[i]) + if (not ".wav" in name[i]) + else os.path.basename(name[i]).split(".")[0] + ), + ) + else: + raise NotImplementedError + todo_waveform = waveform[i, 0] + todo_waveform = ( + todo_waveform / np.max(np.abs(todo_waveform)) + ) * 0.8 # Normalize the energy of the generation output + sf.write(path, todo_waveform, samplerate=self.sampling_rate) + + @torch.no_grad() + def sample_log( + self, + cond, + batch_size, + ddim, + ddim_steps, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_plms=False, + mask=None, + **kwargs, + ): + if mask is not None: + shape = (self.channels, mask.size()[-2], mask.size()[-1]) + else: + shape = (self.channels, self.latent_t_size, self.latent_f_size) + + intermediate = None + if ddim and not use_plms: + ddim_sampler = DDIMSampler(self) + samples, intermediates = ddim_sampler.sample( + ddim_steps, + batch_size, + shape, + cond, + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + mask=mask, + **kwargs, + ) + elif use_plms: + plms_sampler = PLMSSampler(self) + samples, intermediates = plms_sampler.sample( + ddim_steps, + batch_size, + shape, + cond, + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + mask=mask, + unconditional_conditioning=unconditional_conditioning, + **kwargs, + ) + + else: + samples, intermediates = self.sample( + cond=cond, + batch_size=batch_size, + return_intermediates=True, + unconditional_guidance_scale=unconditional_guidance_scale, + mask=mask, + unconditional_conditioning=unconditional_conditioning, + **kwargs, + ) + + return samples, intermediate + + @torch.no_grad() + def generate_batch( + self, + batch, + ddim_steps=200, + ddim_eta=1.0, + x_T=None, + n_gen=1, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_plms=False, + **kwargs, + ): + # Generate n_gen times and select the best + # Batch: audio, text, fnames + assert x_T is None + + if use_plms: + assert ddim_steps is not None + + use_ddim = ddim_steps is not None + + # with self.ema_scope("Plotting"): + for i in range(1): + z, c = self.get_input( + batch, + self.first_stage_key, + unconditional_prob_cfg=0.0, # Do not output unconditional information in the c + ) + + c = self.filter_useful_cond_dict(c) + + text = super().get_input(batch, "text") + + # Generate multiple samples + batch_size = z.shape[0] * n_gen + + # Generate multiple samples at a time and filter out the best + # The condition to the diffusion wrapper can have many format + for cond_key in c.keys(): + if isinstance(c[cond_key], list): + for i in range(len(c[cond_key])): + c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0) + elif isinstance(c[cond_key], dict): + for k in c[cond_key].keys(): + c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0) + else: + c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0) + + text = text * n_gen + + if unconditional_guidance_scale != 1.0: + unconditional_conditioning = {} + for key in self.cond_stage_model_metadata: + model_idx = self.cond_stage_model_metadata[key]["model_idx"] + unconditional_conditioning[key] = self.cond_stage_models[ + model_idx + ].get_unconditional_condition(batch_size) + + fnames = list(super().get_input(batch, "fname")) + samples, _ = self.sample_log( + cond=c, + batch_size=batch_size, + x_T=x_T, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + use_plms=use_plms, + ) + + mel = self.decode_first_stage(samples) + + waveform = self.mel_spectrogram_to_waveform( + mel, savepath="", bs=None, name=fnames, save=False + ) + + if n_gen > 1: + best_index = [] + similarity = self.clap.cos_similarity( + torch.FloatTensor(waveform).squeeze(1), text + ) + for i in range(z.shape[0]): + candidates = similarity[i :: z.shape[0]] + max_index = torch.argmax(candidates).item() + best_index.append(i + max_index * z.shape[0]) + + waveform = waveform[best_index] + + print("Similarity between generated audio and text:") + print(' '.join('{:.2f}'.format(num) for num in similarity.detach().cpu().tolist())) + print("Choose the following indexes as the output:", best_index) + + return waveform + + @torch.no_grad() + def generate_sample( + self, + batchs, + ddim_steps=200, + ddim_eta=1.0, + x_T=None, + n_gen=1, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + name=None, + use_plms=False, + limit_num=None, + **kwargs, + ): + # Generate n_gen times and select the best + # Batch: audio, text, fnames + assert x_T is None + try: + batchs = iter(batchs) + except TypeError: + raise ValueError("The first input argument should be an iterable object") + + if use_plms: + assert ddim_steps is not None + + use_ddim = ddim_steps is not None + if name is None: + name = self.get_validation_folder_name() + + waveform_save_path = os.path.join(self.get_log_dir(), name) + os.makedirs(waveform_save_path, exist_ok=True) + print("Waveform save path: ", waveform_save_path) + + if ( + "audiocaps" in waveform_save_path + and len(os.listdir(waveform_save_path)) >= 964 + ): + print("The evaluation has already been done at %s" % waveform_save_path) + return waveform_save_path + + with self.ema_scope("Plotting"): + for i, batch in enumerate(batchs): + z, c = self.get_input( + batch, + self.first_stage_key, + unconditional_prob_cfg=0.0, # Do not output unconditional information in the c + ) + + if limit_num is not None and i * z.size(0) > limit_num: + break + + c = self.filter_useful_cond_dict(c) + + text = super().get_input(batch, "text") + + # Generate multiple samples + batch_size = z.shape[0] * n_gen + + # Generate multiple samples at a time and filter out the best + # The condition to the diffusion wrapper can have many format + for cond_key in c.keys(): + if isinstance(c[cond_key], list): + for i in range(len(c[cond_key])): + c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0) + elif isinstance(c[cond_key], dict): + for k in c[cond_key].keys(): + c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0) + else: + c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0) + + text = text * n_gen + + if unconditional_guidance_scale != 1.0: + unconditional_conditioning = {} + for key in self.cond_stage_model_metadata: + model_idx = self.cond_stage_model_metadata[key]["model_idx"] + unconditional_conditioning[key] = self.cond_stage_models[ + model_idx + ].get_unconditional_condition(batch_size) + + fnames = list(super().get_input(batch, "fname")) + samples, _ = self.sample_log( + cond=c, + batch_size=batch_size, + x_T=x_T, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + use_plms=use_plms, + ) + + mel = self.decode_first_stage(samples) + + waveform = self.mel_spectrogram_to_waveform( + mel, savepath=waveform_save_path, bs=None, name=fnames, save=False + ) + + if n_gen > 1: + try: + best_index = [] + similarity = self.clap.cos_similarity( + torch.FloatTensor(waveform).squeeze(1), text + ) + for i in range(z.shape[0]): + candidates = similarity[i :: z.shape[0]] + max_index = torch.argmax(candidates).item() + best_index.append(i + max_index * z.shape[0]) + + waveform = waveform[best_index] + + print("Similarity between generated audio and text", similarity) + print("Choose the following indexes:", best_index) + except Exception as e: + print("Warning: while calculating CLAP score (not fatal), ", e) + self.save_waveform(waveform, waveform_save_path, name=fnames) + return waveform_save_path + + +class DiffusionWrapper(nn.Module): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + + self.conditioning_key = conditioning_key + + for key in self.conditioning_key: + if ( + "concat" in key + or "crossattn" in key + or "hybrid" in key + or "film" in key + or "noncond" in key + ): + continue + else: + raise Value("The conditioning key %s is illegal" % key) + + self.being_verbosed_once = False + + def forward(self, x, t, cond_dict: dict = {}): + x = x.contiguous() + t = t.contiguous() + + # x with condition (or maybe not) + xc = x + + y = None + context_list, attn_mask_list = [], [] + + conditional_keys = cond_dict.keys() + + for key in conditional_keys: + if "concat" in key: + xc = torch.cat([x, cond_dict[key].unsqueeze(1)], dim=1) + elif "film" in key: + if y is None: + y = cond_dict[key].squeeze(1) + else: + y = torch.cat([y, cond_dict[key].squeeze(1)], dim=-1) + elif "crossattn" in key: + # assert context is None, "You can only have one context matrix, got %s" % (cond_dict.keys()) + if isinstance(cond_dict[key], dict): + for k in cond_dict[key].keys(): + if "crossattn" in k: + context, attn_mask = cond_dict[key][ + k + ] # crossattn_audiomae_pooled: torch.Size([12, 128, 768]) + else: + assert len(cond_dict[key]) == 2, ( + "The context condition for %s you returned should have two element, one context one mask" + % (key) + ) + context, attn_mask = cond_dict[key] + + # The input to the UNet model is a list of context matrix + context_list.append(context) + attn_mask_list.append(attn_mask) + + elif ( + "noncond" in key + ): # If you use loss function in the conditional module, include the keyword "noncond" in the return dictionary + continue + else: + raise NotImplementedError() + + # if(not self.being_verbosed_once): + # print("The input shape to the diffusion model is as follows:") + # print("xc", xc.size()) + # print("t", t.size()) + # for i in range(len(context_list)): + # print("context_%s" % i, context_list[i].size(), attn_mask_list[i].size()) + # if(y is not None): + # print("y", y.size()) + # self.being_verbosed_once = True + out = self.diffusion_model( + xc, t, context_list=context_list, y=y, context_attn_mask_list=attn_mask_list + ) + return out + self.warmup_step() + + if ( + self.state is None + and len(self.trainer.optimizers[0].state_dict()["state"].keys()) > 0 + ): + self.state = ( + self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"].clone() + ) + elif self.state is not None and batch_idx % 1000 == 0: + assert ( + torch.sum( + torch.abs( + self.state + - self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"] + ) + ) + > 1e-7 + ), "Optimizer is not working" + + if len(self.metrics_buffer.keys()) > 0: + for k in self.metrics_buffer.keys(): + self.log( + k, + self.metrics_buffer[k], + prog_bar=False, + logger=True, + on_step=True, + on_epoch=False, + ) + print(k, self.metrics_buffer[k]) + self.metrics_buffer = {} + + loss, loss_dict = self.shared_step(batch) + + self.log_dict( + {k: float(v) for k, v in loss_dict.items()}, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + + self.log( + "global_step", + float(self.global_step), + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, + ) + + lr = self.trainer.optimizers[0].param_groups[0]["lr"] + self.log( + "lr_abs", + float(lr), + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, + ) + + +if __name__ == "__main__": + import yaml + + model_config = "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/stable-diffusion/models/ldm/text2img256/config.yaml" + model_config = yaml.load(open(model_config, "r"), Loader=yaml.FullLoader) + + latent_diffusion = LatentDiffusion(**model_config["model"]["params"]) + + import ipdb + + ipdb.set_trace() diff --git a/audioldm2/latent_diffusion/models/plms.py b/audioldm2/latent_diffusion/models/plms.py new file mode 100755 index 0000000000000000000000000000000000000000..9c80796442bd653ac3dc1970c12f621068a4d821 --- /dev/null +++ b/audioldm2/latent_diffusion/models/plms.py @@ -0,0 +1,360 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from audioldm2.latent_diffusion.modules.diffusionmodules.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, +) + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + if ddim_eta != 0: + ddim_eta = 0 + # raise ValueError('ddim_eta must be 0 for PLMS') + + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs, + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f"Data shape for PLMS sampling is {size}") + + samples, intermediates = self.plms_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + list(reversed(range(0, timesteps))) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full( + (b,), + time_range[min(i + 1, len(time_range) - 1)], + device=device, + dtype=torch.long, + ) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? + img = img_orig * mask + (1.0 - mask) * img + + outs = self.p_sample_plms( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, + t_next=ts_next, + ) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + ): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if ( + unconditional_conditioning is None + or unconditional_guidance_scale == 1.0 + ): + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = ( + 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] + ) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/audioldm2/latent_diffusion/modules/__init__.py b/audioldm2/latent_diffusion/modules/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm2/latent_diffusion/modules/attention.py b/audioldm2/latent_diffusion/modules/attention.py new file mode 100755 index 0000000000000000000000000000000000000000..6116342da98249c681ddb5f696b48dc0f5ac69f2 --- /dev/null +++ b/audioldm2/latent_diffusion/modules/attention.py @@ -0,0 +1,467 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from audioldm2.latent_diffusion.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +# class CrossAttention(nn.Module): +# """ +# ### Cross Attention Layer +# This falls-back to self-attention when conditional embeddings are not specified. +# """ + +# use_flash_attention: bool = True + +# # use_flash_attention: bool = False +# def __init__( +# self, +# query_dim, +# context_dim=None, +# heads=8, +# dim_head=64, +# dropout=0.0, +# is_inplace: bool = True, +# ): +# # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True): +# """ +# :param d_model: is the input embedding size +# :param n_heads: is the number of attention heads +# :param d_head: is the size of a attention head +# :param d_cond: is the size of the conditional embeddings +# :param is_inplace: specifies whether to perform the attention softmax computation inplace to +# save memory +# """ +# super().__init__() + +# self.is_inplace = is_inplace +# self.n_heads = heads +# self.d_head = dim_head + +# # Attention scaling factor +# self.scale = dim_head**-0.5 + +# # The normal self-attention layer +# if context_dim is None: +# context_dim = query_dim + +# # Query, key and value mappings +# d_attn = dim_head * heads +# self.to_q = nn.Linear(query_dim, d_attn, bias=False) +# self.to_k = nn.Linear(context_dim, d_attn, bias=False) +# self.to_v = nn.Linear(context_dim, d_attn, bias=False) + +# # Final linear layer +# self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout)) + +# # Setup [flash attention](https://github.com/HazyResearch/flash-attention). +# # Flash attention is only used if it's installed +# # and `CrossAttention.use_flash_attention` is set to `True`. +# try: +# # You can install flash attention by cloning their Github repo, +# # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention) +# # and then running `python setup.py install` +# from flash_attn.flash_attention import FlashAttention + +# self.flash = FlashAttention() +# # Set the scale for scaled dot-product attention. +# self.flash.softmax_scale = self.scale +# # Set to `None` if it's not installed +# except ImportError: +# self.flash = None + +# def forward(self, x, context=None, mask=None): +# """ +# :param x: are the input embeddings of shape `[batch_size, height * width, d_model]` +# :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` +# """ + +# # If `cond` is `None` we perform self attention +# has_cond = context is not None +# if not has_cond: +# context = x + +# # Get query, key and value vectors +# q = self.to_q(x) +# k = self.to_k(context) +# v = self.to_v(context) + +# # Use flash attention if it's available and the head size is less than or equal to `128` +# if ( +# CrossAttention.use_flash_attention +# and self.flash is not None +# and not has_cond +# and self.d_head <= 128 +# ): +# return self.flash_attention(q, k, v) +# # Otherwise, fallback to normal attention +# else: +# return self.normal_attention(q, k, v) + +# def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): +# """ +# #### Flash Attention +# :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# """ + +# # Get batch size and number of elements along sequence axis (`width * height`) +# batch_size, seq_len, _ = q.shape + +# # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of +# # shape `[batch_size, seq_len, 3, n_heads * d_head]` +# qkv = torch.stack((q, k, v), dim=2) +# # Split the heads +# qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head) + +# # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to +# # fit this size. +# if self.d_head <= 32: +# pad = 32 - self.d_head +# elif self.d_head <= 64: +# pad = 64 - self.d_head +# elif self.d_head <= 128: +# pad = 128 - self.d_head +# else: +# raise ValueError(f"Head size ${self.d_head} too large for Flash Attention") + +# # Pad the heads +# if pad: +# qkv = torch.cat( +# (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1 +# ) + +# # Compute attention +# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ +# # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]` +# # TODO here I add the dtype changing +# out, _ = self.flash(qkv.type(torch.float16)) +# # Truncate the extra head size +# out = out[:, :, :, : self.d_head].float() +# # Reshape to `[batch_size, seq_len, n_heads * d_head]` +# out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head) + +# # Map to `[batch_size, height * width, d_model]` with a linear layer +# return self.to_out(out) + +# def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): +# """ +# #### Normal Attention + +# :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# """ + +# # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]` +# q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32] +# k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32] +# v = v.view(*v.shape[:2], self.n_heads, -1) + +# # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$ +# attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale + +# # Compute softmax +# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$ +# if self.is_inplace: +# half = attn.shape[0] // 2 +# attn[half:] = attn[half:].softmax(dim=-1) +# attn[:half] = attn[:half].softmax(dim=-1) +# else: +# attn = attn.softmax(dim=-1) + +# # Compute attention output +# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ +# # attn: [bs, 20, 64, 1] +# # v: [bs, 1, 20, 32] +# out = torch.einsum("bhij,bjhd->bihd", attn, v) +# # Reshape to `[batch_size, height * width, n_heads * d_head]` +# out = out.reshape(*out.shape[:2], -1) +# # Map to `[batch_size, height * width, d_model]` with a linear layer +# return self.to_out(out) + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + sim.masked_fill_(~(mask == 1), max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None, mask=None): + if context is None: + return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint) + else: + return checkpoint( + self._forward, (x, context, mask), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None, mask=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context, mask=mask) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + ): + super().__init__() + + context_dim = context_dim + + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, x, context=None, mask=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + for block in self.transformer_blocks: + x = block(x, context=context, mask=mask) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + return x + x_in diff --git a/audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py b/audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py new file mode 100755 index 0000000000000000000000000000000000000000..f02fa05e163076641b92bbeabceb5f89edb0f18e --- /dev/null +++ b/audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py @@ -0,0 +1,149 @@ +""" +Reference Repo: https://github.com/facebookresearch/AudioMAE +""" + +import torch +import torch.nn as nn +from timm.models.layers import to_2tuple +import audioldm2.latent_diffusion.modules.audiomae.models_vit as models_vit +import audioldm2.latent_diffusion.modules.audiomae.models_mae as models_mae + +# model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128)) + + +class PatchEmbed_new(nn.Module): + """Flexible Image to Patch Embedding""" + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10 + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + + self.img_size = img_size + self.patch_size = patch_size + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride + ) # with overlapped patches + # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) + # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w + self.patch_hw = (h, w) + self.num_patches = h * w + + def get_output_shape(self, img_size): + # todo: don't be lazy.. + return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class AudioMAE(nn.Module): + """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)""" + + def __init__( + self, + ): + super().__init__() + model = models_vit.__dict__["vit_base_patch16"]( + num_classes=527, + drop_path_rate=0.1, + global_pool=True, + mask_2d=True, + use_custom_patch=False, + ) + + img_size = (1024, 128) + emb_dim = 768 + + model.patch_embed = PatchEmbed_new( + img_size=img_size, + patch_size=(16, 16), + in_chans=1, + embed_dim=emb_dim, + stride=16, + ) + num_patches = model.patch_embed.num_patches + # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8 + model.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False + ) # fixed sin-cos embedding + + # checkpoint_path = '/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth' + # checkpoint = torch.load(checkpoint_path, map_location='cpu') + # msg = model.load_state_dict(checkpoint['model'], strict=False) + # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}') + + self.model = model + + def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0): + """ + x: mel fbank [Batch, 1, T, F] + mask_t_prob: 'T masking ratio (percentage of removed patches).' + mask_f_prob: 'F masking ratio (percentage of removed patches).' + """ + return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob) + + +class Vanilla_AudioMAE(nn.Module): + """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)""" + + def __init__( + self, + ): + super().__init__() + model = models_mae.__dict__["mae_vit_base_patch16"]( + in_chans=1, audio_exp=True, img_size=(1024, 128) + ) + + # checkpoint_path = '/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth' + # checkpoint = torch.load(checkpoint_path, map_location='cpu') + # msg = model.load_state_dict(checkpoint['model'], strict=False) + + # Skip the missing keys of decoder modules (not required) + # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}') + + self.model = model.eval() + + def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False): + """ + x: mel fbank [Batch, 1, 1024 (T), 128 (F)] + mask_ratio: 'masking ratio (percentage of removed patches).' + """ + with torch.no_grad(): + # embed: [B, 513, 768] for mask_ratio=0.0 + if no_mask: + if no_average: + raise RuntimeError("This function is deprecated") + embed = self.model.forward_encoder_no_random_mask_no_average( + x + ) # mask_ratio + else: + embed = self.model.forward_encoder_no_mask(x) # mask_ratio + else: + raise RuntimeError("This function is deprecated") + embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio) + return embed + + +if __name__ == "__main__": + model = Vanilla_AudioMAE().cuda() + input = torch.randn(4, 1, 1024, 128).cuda() + print("The first run") + embed = model(input, mask_ratio=0.0, no_mask=True) + print(embed) + print("The second run") + embed = model(input, mask_ratio=0.0) + print(embed) diff --git a/audioldm2/latent_diffusion/modules/audiomae/__init__.py b/audioldm2/latent_diffusion/modules/audiomae/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm2/latent_diffusion/modules/audiomae/models_mae.py b/audioldm2/latent_diffusion/modules/audiomae/models_mae.py new file mode 100755 index 0000000000000000000000000000000000000000..7ab0076710a08a7451dd4096bd6eb2f8f6e641aa --- /dev/null +++ b/audioldm2/latent_diffusion/modules/audiomae/models_mae.py @@ -0,0 +1,613 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +from functools import partial + +import torch +import torch.nn as nn + +from timm.models.vision_transformer import Block +from audioldm2.latent_diffusion.modules.audiomae.util.pos_embed import ( + get_2d_sincos_pos_embed, + get_2d_sincos_pos_embed_flexible, +) +from audioldm2.latent_diffusion.modules.audiomae.util.patch_embed import ( + PatchEmbed_new, + PatchEmbed_org, +) + + +class MaskedAutoencoderViT(nn.Module): + """Masked Autoencoder with VisionTransformer backbone""" + + def __init__( + self, + img_size=224, + patch_size=16, + stride=10, + in_chans=3, + embed_dim=1024, + depth=24, + num_heads=16, + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4.0, + norm_layer=nn.LayerNorm, + norm_pix_loss=False, + audio_exp=False, + alpha=0.0, + temperature=0.2, + mode=0, + contextual_depth=8, + use_custom_patch=False, + split_pos=False, + pos_trainable=False, + use_nce=False, + beta=4.0, + decoder_mode=0, + mask_t_prob=0.6, + mask_f_prob=0.5, + mask_2d=False, + epoch=0, + no_shift=False, + ): + super().__init__() + + self.audio_exp = audio_exp + self.embed_dim = embed_dim + self.decoder_embed_dim = decoder_embed_dim + # -------------------------------------------------------------------------- + # MAE encoder specifics + if use_custom_patch: + print( + f"Use custom patch_emb with patch size: {patch_size}, stride: {stride}" + ) + self.patch_embed = PatchEmbed_new( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + stride=stride, + ) + else: + self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim) + self.use_custom_patch = use_custom_patch + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + # self.split_pos = split_pos # not useful + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable + ) # fixed sin-cos embedding + + self.encoder_depth = depth + self.contextual_depth = contextual_depth + self.blocks = nn.ModuleList( + [ + Block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) # qk_scale=None + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) + + # -------------------------------------------------------------------------- + # MAE decoder specifics + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, decoder_embed_dim), + requires_grad=pos_trainable, + ) # fixed sin-cos embedding + + self.no_shift = no_shift + + self.decoder_mode = decoder_mode + if ( + self.use_custom_patch + ): # overlapped patches as in AST. Similar performance yet compute heavy + window_size = (6, 6) + feat_size = (102, 12) + else: + window_size = (4, 4) + feat_size = (64, 8) + if self.decoder_mode == 1: + decoder_modules = [] + for index in range(16): + if self.no_shift: + shift_size = (0, 0) + else: + if (index % 2) == 0: + shift_size = (0, 0) + else: + shift_size = (2, 0) + # shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]) + decoder_modules.append( + SwinTransformerBlock( + dim=decoder_embed_dim, + num_heads=16, + feat_size=feat_size, + window_size=window_size, + shift_size=shift_size, + mlp_ratio=mlp_ratio, + drop=0.0, + drop_attn=0.0, + drop_path=0.0, + extra_norm=False, + sequential_attn=False, + norm_layer=norm_layer, # nn.LayerNorm, + ) + ) + self.decoder_blocks = nn.ModuleList(decoder_modules) + else: + # Transfomer + self.decoder_blocks = nn.ModuleList( + [ + Block( + decoder_embed_dim, + decoder_num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) # qk_scale=None, + for i in range(decoder_depth) + ] + ) + + self.decoder_norm = norm_layer(decoder_embed_dim) + self.decoder_pred = nn.Linear( + decoder_embed_dim, patch_size**2 * in_chans, bias=True + ) # decoder to patch + + # -------------------------------------------------------------------------- + + self.norm_pix_loss = norm_pix_loss + + self.patch_size = patch_size + self.stride = stride + + # audio exps + self.alpha = alpha + self.T = temperature + self.mode = mode + self.use_nce = use_nce + self.beta = beta + + self.log_softmax = nn.LogSoftmax(dim=-1) + + self.mask_t_prob = mask_t_prob + self.mask_f_prob = mask_f_prob + self.mask_2d = mask_2d + + self.epoch = epoch + + self.initialize_weights() + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + if self.audio_exp: + pos_embed = get_2d_sincos_pos_embed_flexible( + self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True + ) + else: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.patch_embed.num_patches**0.5), + cls_token=True, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + if self.audio_exp: + decoder_pos_embed = get_2d_sincos_pos_embed_flexible( + self.decoder_pos_embed.shape[-1], + self.patch_embed.patch_hw, + cls_token=True, + ) + else: + decoder_pos_embed = get_2d_sincos_pos_embed( + self.decoder_pos_embed.shape[-1], + int(self.patch_embed.num_patches**0.5), + cls_token=True, + ) + self.decoder_pos_embed.data.copy_( + torch.from_numpy(decoder_pos_embed).float().unsqueeze(0) + ) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.cls_token, std=0.02) + torch.nn.init.normal_(self.mask_token, std=0.02) + + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + L = (H/p)*(W/p) + """ + p = self.patch_embed.patch_size[0] + # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + if self.audio_exp: + if self.use_custom_patch: # overlapped patch + h, w = self.patch_embed.patch_hw + # todo: fixed h/w patch size and stride size. Make hw custom in the future + x = imgs.unfold(2, self.patch_size, self.stride).unfold( + 3, self.patch_size, self.stride + ) # n,1,H,W -> n,1,h,w,p,p + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) + # x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p)) + # x = torch.einsum('nchpwq->nhwpqc', x) + # x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) + else: + h = imgs.shape[2] // p + w = imgs.shape[3] // p + # h,w = self.patch_embed.patch_hw + x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p)) + x = torch.einsum("nchpwq->nhwpqc", x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) + else: + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum("nchpwq->nhwpqc", x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + specs: (N, 1, H, W) + """ + p = self.patch_embed.patch_size[0] + h = 1024 // p + w = 128 // p + x = x.reshape(shape=(x.shape[0], h, w, p, p, 1)) + x = torch.einsum("nhwpqc->nchpwq", x) + specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p)) + return specs + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def random_masking_2d(self, x, mask_t_prob, mask_f_prob): + """ + 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob) + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + if self.use_custom_patch: # overlapped patch + T = 101 + F = 12 + else: + T = 64 + F = 8 + # x = x.reshape(N, T, F, D) + len_keep_t = int(T * (1 - mask_t_prob)) + len_keep_f = int(F * (1 - mask_f_prob)) + + # noise for mask in time + noise_t = torch.rand(N, T, device=x.device) # noise in [0, 1] + # sort noise for each sample aling time + ids_shuffle_t = torch.argsort( + noise_t, dim=1 + ) # ascend: small is keep, large is remove + ids_restore_t = torch.argsort(ids_shuffle_t, dim=1) + ids_keep_t = ids_shuffle_t[:, :len_keep_t] + # noise mask in freq + noise_f = torch.rand(N, F, device=x.device) # noise in [0, 1] + ids_shuffle_f = torch.argsort( + noise_f, dim=1 + ) # ascend: small is keep, large is remove + ids_restore_f = torch.argsort(ids_shuffle_f, dim=1) + ids_keep_f = ids_shuffle_f[:, :len_keep_f] # + + # generate the binary mask: 0 is keep, 1 is remove + # mask in freq + mask_f = torch.ones(N, F, device=x.device) + mask_f[:, :len_keep_f] = 0 + mask_f = ( + torch.gather(mask_f, dim=1, index=ids_restore_f) + .unsqueeze(1) + .repeat(1, T, 1) + ) # N,T,F + # mask in time + mask_t = torch.ones(N, T, device=x.device) + mask_t[:, :len_keep_t] = 0 + mask_t = ( + torch.gather(mask_t, dim=1, index=ids_restore_t) + .unsqueeze(1) + .repeat(1, F, 1) + .permute(0, 2, 1) + ) # N,T,F + mask = 1 - (1 - mask_t) * (1 - mask_f) # N, T, F + + # get masked x + id2res = torch.Tensor(list(range(N * T * F))).reshape(N, T, F).to(x.device) + id2res = id2res + 999 * mask # add a large value for masked elements + id2res2 = torch.argsort(id2res.flatten(start_dim=1)) + ids_keep = id2res2.flatten(start_dim=1)[:, : len_keep_f * len_keep_t] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + ids_restore = torch.argsort(id2res2.flatten(start_dim=1)) + mask = mask.flatten(start_dim=1) + + return x_masked, mask, ids_restore + + def forward_encoder(self, x, mask_ratio, mask_2d=False): + # embed patches + x = self.patch_embed(x) + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + if mask_2d: + x, mask, ids_restore = self.random_masking_2d( + x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob + ) + else: + x, mask, ids_restore = self.random_masking(x, mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + return x, mask, ids_restore, None + + def forward_encoder_no_random_mask_no_average(self, x): + # embed patches + x = self.patch_embed(x) + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + # if mask_2d: + # x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob) + # else: + # x, mask, ids_restore = self.random_masking(x, mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + return x + + def forward_encoder_no_mask(self, x): + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + # x, mask, ids_restore = self.random_masking(x, mask_ratio) + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + contextual_embs = [] + for n, blk in enumerate(self.blocks): + x = blk(x) + if n > self.contextual_depth: + contextual_embs.append(self.norm(x)) + # x = self.norm(x) + contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0) + + return contextual_emb + + def forward_decoder(self, x, ids_restore): + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1 + ) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + x_ = torch.gather( + x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]) + ) # unshuffle + x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token + + # add pos embed + x = x + self.decoder_pos_embed + + if self.decoder_mode != 0: + B, L, D = x.shape + x = x[:, 1:, :] + if self.use_custom_patch: + x = x.reshape(B, 101, 12, D) + x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1) # hack + x = x.reshape(B, 1224, D) + if self.decoder_mode > 3: # mvit + x = self.decoder_blocks(x) + else: + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + pred = self.decoder_pred(x) + + # remove cls token + if self.decoder_mode != 0: + if self.use_custom_patch: + pred = pred.reshape(B, 102, 12, 256) + pred = pred[:, :101, :, :] + pred = pred.reshape(B, 1212, 256) + else: + pred = pred + else: + pred = pred[:, 1:, :] + return pred, None, None # emb, emb_pixel + + def forward_loss(self, imgs, pred, mask, norm_pix_loss=False): + """ + imgs: [N, 3, H, W] + pred: [N, L, p*p*3] + mask: [N, L], 0 is keep, 1 is remove, + """ + target = self.patchify(imgs) + if norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.0e-6) ** 0.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + return loss + + def forward(self, imgs, mask_ratio=0.8): + emb_enc, mask, ids_restore, _ = self.forward_encoder( + imgs, mask_ratio, mask_2d=self.mask_2d + ) + pred, _, _ = self.forward_decoder(emb_enc, ids_restore) # [N, L, p*p*3] + loss_recon = self.forward_loss( + imgs, pred, mask, norm_pix_loss=self.norm_pix_loss + ) + loss_contrastive = torch.FloatTensor([0.0]).cuda() + return loss_recon, pred, mask, loss_contrastive + + +def mae_vit_small_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + decoder_embed_dim=512, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +def mae_vit_base_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + decoder_embed_dim=512, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +def mae_vit_large_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + decoder_embed_dim=512, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +def mae_vit_huge_patch14_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + decoder_embed_dim=512, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +# set recommended archs +mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks diff --git a/audioldm2/latent_diffusion/modules/audiomae/models_vit.py b/audioldm2/latent_diffusion/modules/audiomae/models_vit.py new file mode 100755 index 0000000000000000000000000000000000000000..cb37adbc16cfb9a232493c473c9400f199655b6c --- /dev/null +++ b/audioldm2/latent_diffusion/modules/audiomae/models_vit.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +from functools import partial + +import torch +import torch.nn as nn +import timm.models.vision_transformer + + +class VisionTransformer(timm.models.vision_transformer.VisionTransformer): + """Vision Transformer with support for global average pooling""" + + def __init__( + self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs + ): + super(VisionTransformer, self).__init__(**kwargs) + + self.global_pool = global_pool + if self.global_pool: + norm_layer = kwargs["norm_layer"] + embed_dim = kwargs["embed_dim"] + self.fc_norm = norm_layer(embed_dim) + del self.norm # remove the original norm + self.mask_2d = mask_2d + self.use_custom_patch = use_custom_patch + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed[:, 1:, :] + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + if self.global_pool: + x = x[:, 1:, :].mean(dim=1) # global pool without cls token + outcome = self.fc_norm(x) + else: + x = self.norm(x) + outcome = x[:, 0] + + return outcome + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def random_masking_2d(self, x, mask_t_prob, mask_f_prob): + """ + 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob) + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + + N, L, D = x.shape # batch, length, dim + if self.use_custom_patch: + # # for AS + T = 101 # 64,101 + F = 12 # 8,12 + # # for ESC + # T=50 + # F=12 + # for SPC + # T=12 + # F=12 + else: + # ## for AS + T = 64 + F = 8 + # ## for ESC + # T=32 + # F=8 + ## for SPC + # T=8 + # F=8 + + # mask T + x = x.reshape(N, T, F, D) + len_keep_T = int(T * (1 - mask_t_prob)) + noise = torch.rand(N, T, device=x.device) # noise in [0, 1] + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_keep = ids_shuffle[:, :len_keep_T] + index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D) + # x_masked = torch.gather(x, dim=1, index=index) + # x_masked = x_masked.reshape(N,len_keep_T*F,D) + x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D + + # mask F + # x = x.reshape(N, T, F, D) + x = x.permute(0, 2, 1, 3) # N T' F D => N F T' D + len_keep_F = int(F * (1 - mask_f_prob)) + noise = torch.rand(N, F, device=x.device) # noise in [0, 1] + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_keep = ids_shuffle[:, :len_keep_F] + # index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D) + index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D) + x_masked = torch.gather(x, dim=1, index=index) + x_masked = x_masked.permute(0, 2, 1, 3) # N F' T' D => N T' F' D + # x_masked = x_masked.reshape(N,len_keep*T,D) + x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D) + + return x_masked, None, None + + def forward_features_mask(self, x, mask_t_prob, mask_f_prob): + B = x.shape[0] # 4,1,1024,128 + x = self.patch_embed(x) # 4, 512, 768 + + x = x + self.pos_embed[:, 1:, :] + if self.random_masking_2d: + x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob) + else: + x, mask, ids_restore = self.random_masking(x, mask_t_prob) + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = self.pos_drop(x) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + + if self.global_pool: + x = x[:, 1:, :].mean(dim=1) # global pool without cls token + outcome = self.fc_norm(x) + else: + x = self.norm(x) + outcome = x[:, 0] + + return outcome + + # overwrite original timm + def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0): + if mask_t_prob > 0.0 or mask_f_prob > 0.0: + x = self.forward_features_mask( + x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob + ) + else: + x = self.forward_features(x) + x = self.head(x) + return x + + +def vit_small_patch16(**kwargs): + model = VisionTransformer( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model + + +def vit_base_patch16(**kwargs): + model = VisionTransformer( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model + + +def vit_large_patch16(**kwargs): + model = VisionTransformer( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model + + +def vit_huge_patch14(**kwargs): + model = VisionTransformer( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/crop.py b/audioldm2/latent_diffusion/modules/audiomae/util/crop.py new file mode 100755 index 0000000000000000000000000000000000000000..525e3c783c3d348e593dc89c2b5fb8520918e9ea --- /dev/null +++ b/audioldm2/latent_diffusion/modules/audiomae/util/crop.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch + +from torchvision import transforms +from torchvision.transforms import functional as F + + +class RandomResizedCrop(transforms.RandomResizedCrop): + """ + RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. + This may lead to results different with torchvision's version. + Following BYOL's TF code: + https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 + """ + + @staticmethod + def get_params(img, scale, ratio): + width, height = F._get_image_size(img) + area = height * width + + target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + log_ratio = torch.log(torch.tensor(ratio)) + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) + ).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + w = min(w, width) + h = min(h, height) + + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + + return i, j, h, w diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/datasets.py b/audioldm2/latent_diffusion/modules/audiomae/util/datasets.py new file mode 100755 index 0000000000000000000000000000000000000000..b90f89a7d5f78c31bc9113dd88b632b0c234f10a --- /dev/null +++ b/audioldm2/latent_diffusion/modules/audiomae/util/datasets.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +import os +import PIL + +from torchvision import datasets, transforms + +from timm.data import create_transform +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def build_dataset(is_train, args): + transform = build_transform(is_train, args) + + root = os.path.join(args.data_path, "train" if is_train else "val") + dataset = datasets.ImageFolder(root, transform=transform) + + print(dataset) + + return dataset + + +def build_transform(is_train, args): + mean = IMAGENET_DEFAULT_MEAN + std = IMAGENET_DEFAULT_STD + # train transform + if is_train: + # this should always dispatch to transforms_imagenet_train + transform = create_transform( + input_size=args.input_size, + is_training=True, + color_jitter=args.color_jitter, + auto_augment=args.aa, + interpolation="bicubic", + re_prob=args.reprob, + re_mode=args.remode, + re_count=args.recount, + mean=mean, + std=std, + ) + return transform + + # eval transform + t = [] + if args.input_size <= 224: + crop_pct = 224 / 256 + else: + crop_pct = 1.0 + size = int(args.input_size / crop_pct) + t.append( + transforms.Resize( + size, interpolation=PIL.Image.BICUBIC + ), # to maintain same ratio w.r.t. 224 images + ) + t.append(transforms.CenterCrop(args.input_size)) + + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(mean, std)) + return transforms.Compose(t) diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/lars.py b/audioldm2/latent_diffusion/modules/audiomae/util/lars.py new file mode 100755 index 0000000000000000000000000000000000000000..fc43923d22cf2c9af4ae9166612c3f3477faf254 --- /dev/null +++ b/audioldm2/latent_diffusion/modules/audiomae/util/lars.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# LARS optimizer, implementation from MoCo v3: +# https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- + +import torch + + +class LARS(torch.optim.Optimizer): + """ + LARS optimizer, no rate scaling or weight decay for parameters <= 1D. + """ + + def __init__( + self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001 + ): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + trust_coefficient=trust_coefficient, + ) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self): + for g in self.param_groups: + for p in g["params"]: + dp = p.grad + + if dp is None: + continue + + if p.ndim > 1: # if not normalization gamma/beta or bias + dp = dp.add(p, alpha=g["weight_decay"]) + param_norm = torch.norm(p) + update_norm = torch.norm(dp) + one = torch.ones_like(param_norm) + q = torch.where( + param_norm > 0.0, + torch.where( + update_norm > 0, + (g["trust_coefficient"] * param_norm / update_norm), + one, + ), + one, + ) + dp = dp.mul(q) + + param_state = self.state[p] + if "mu" not in param_state: + param_state["mu"] = torch.zeros_like(p) + mu = param_state["mu"] + mu.mul_(g["momentum"]).add_(dp) + p.add_(mu, alpha=-g["lr"]) diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py b/audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py new file mode 100755 index 0000000000000000000000000000000000000000..e90ed69d7b8d019dbf5d90571541668e2bd8efe8 --- /dev/null +++ b/audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# ELECTRA https://github.com/google-research/electra +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + + +def param_groups_lrd( + model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75 +): + """ + Parameter groups for layer-wise lr decay + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 + """ + param_group_names = {} + param_groups = {} + + num_layers = len(model.blocks) + 1 + + layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + + # no decay: all 1D parameters and model specific ones + if p.ndim == 1 or n in no_weight_decay_list: + g_decay = "no_decay" + this_decay = 0.0 + else: + g_decay = "decay" + this_decay = weight_decay + + layer_id = get_layer_id_for_vit(n, num_layers) + group_name = "layer_%d_%s" % (layer_id, g_decay) + + if group_name not in param_group_names: + this_scale = layer_scales[layer_id] + + param_group_names[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + param_groups[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + + param_group_names[group_name]["params"].append(n) + param_groups[group_name]["params"].append(p) + + # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) + + return list(param_groups.values()) + + +def get_layer_id_for_vit(name, num_layers): + """ + Assign a parameter with its layer id + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + """ + if name in ["cls_token", "pos_embed"]: + return 0 + elif name.startswith("patch_embed"): + return 0 + elif name.startswith("blocks"): + return int(name.split(".")[1]) + 1 + else: + return num_layers diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py b/audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py new file mode 100755 index 0000000000000000000000000000000000000000..efe184d8e3fb63ec6b4f83375b6ea719985900de --- /dev/null +++ b/audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate with half-cycle cosine after warmup""" + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( + 1.0 + + math.cos( + math.pi + * (epoch - args.warmup_epochs) + / (args.epochs - args.warmup_epochs) + ) + ) + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + return lr diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/misc.py b/audioldm2/latent_diffusion/modules/audiomae/util/misc.py new file mode 100755 index 0000000000000000000000000000000000000000..74184e09e23e0e174350b894b0cff29600c18b71 --- /dev/null +++ b/audioldm2/latent_diffusion/modules/audiomae/util/misc.py @@ -0,0 +1,453 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import time +from collections import defaultdict, deque +from pathlib import Path + +import torch +import torch.distributed as dist +from torch._six import inf + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print("[{}] ".format(now), end="") # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + args.dist_url = "tcp://%s:%s" % ( + os.environ["MASTER_ADDR"], + os.environ["MASTER_PORT"], + ) + os.environ["LOCAL_RANK"] = str(args.gpu) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}): {}, gpu {}".format( + args.rank, args.dist_url, args.gpu + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__( + self, + loss, + optimizer, + clip_grad=None, + parameters=None, + create_graph=False, + update_grad=True, + ): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_( + optimizer + ) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.0) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm( + torch.stack( + [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] + ), + norm_type, + ) + return total_norm + + +def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): + output_dir = Path(args.output_dir) + epoch_name = str(epoch) + if loss_scaler is not None: + checkpoint_paths = [output_dir / ("checkpoint-%s.pth" % epoch_name)] + for checkpoint_path in checkpoint_paths: + to_save = { + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch, + "scaler": loss_scaler.state_dict(), + "args": args, + } + + save_on_master(to_save, checkpoint_path) + else: + client_state = {"epoch": epoch} + model.save_checkpoint( + save_dir=args.output_dir, + tag="checkpoint-%s" % epoch_name, + client_state=client_state, + ) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + if args.resume: + if args.resume.startswith("https"): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location="cpu", check_hash=True + ) + else: + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + print("Resume checkpoint %s" % args.resume) + if ( + "optimizer" in checkpoint + and "epoch" in checkpoint + and not (hasattr(args, "eval") and args.eval) + ): + optimizer.load_state_dict(checkpoint["optimizer"]) + args.start_epoch = checkpoint["epoch"] + 1 + if "scaler" in checkpoint: + loss_scaler.load_state_dict(checkpoint["scaler"]) + print("With optim & sched!") + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x + + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +def merge_vmae_to_avmae(avmae_state_dict, vmae_ckpt): + # keys_to_copy=['pos_embed','patch_embed'] + # replaced=0 + + vmae_ckpt["cls_token"] = vmae_ckpt["cls_token_v"] + vmae_ckpt["mask_token"] = vmae_ckpt["mask_token_v"] + + # pos_emb % not trainable, use default + pos_embed_v = vmae_ckpt["pos_embed_v"] # 1,589,768 + pos_embed = pos_embed_v[:, 1:, :] # 1,588,768 + cls_embed = pos_embed_v[:, 0, :].unsqueeze(1) + pos_embed = pos_embed.reshape(1, 2, 14, 14, 768).sum(dim=1) # 1, 14, 14, 768 + print("Position interpolate from 14,14 to 64,8") + pos_embed = pos_embed.permute(0, 3, 1, 2) # 1, 14,14,768 -> 1,768,14,14 + pos_embed = torch.nn.functional.interpolate( + pos_embed, size=(64, 8), mode="bicubic", align_corners=False + ) + pos_embed = pos_embed.permute(0, 2, 3, 1).flatten( + 1, 2 + ) # 1, 14, 14, 768 => 1, 196,768 + pos_embed = torch.cat((cls_embed, pos_embed), dim=1) + assert vmae_ckpt["pos_embed"].shape == pos_embed.shape + vmae_ckpt["pos_embed"] = pos_embed + # patch_emb + # aggregate 3 channels in video-rgb ckpt to 1 channel for audio + v_weight = vmae_ckpt["patch_embed_v.proj.weight"] # 768,3,2,16,16 + new_proj_weight = torch.nn.Parameter(v_weight.sum(dim=2).sum(dim=1).unsqueeze(1)) + assert new_proj_weight.shape == vmae_ckpt["patch_embed.proj.weight"].shape + vmae_ckpt["patch_embed.proj.weight"] = new_proj_weight + vmae_ckpt["patch_embed.proj.bias"] = vmae_ckpt["patch_embed_v.proj.bias"] + + # hack + vmae_ckpt["norm.weight"] = vmae_ckpt["norm_v.weight"] + vmae_ckpt["norm.bias"] = vmae_ckpt["norm_v.bias"] + + # replace transformer encoder + for k, v in vmae_ckpt.items(): + if k.startswith("blocks."): + kk = k.replace("blocks.", "blocks_v.") + vmae_ckpt[k] = vmae_ckpt[kk] + elif k.startswith("blocks_v."): + pass + else: + print(k) + print(k) diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py b/audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py new file mode 100755 index 0000000000000000000000000000000000000000..ac1e4d436c6f79aef9bf1de32cdac5d4f037c775 --- /dev/null +++ b/audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py @@ -0,0 +1,127 @@ +import torch +import torch.nn as nn +from timm.models.layers import to_2tuple + + +class PatchEmbed_org(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + y = x.flatten(2).transpose(1, 2) + return y + + +class PatchEmbed_new(nn.Module): + """Flexible Image to Patch Embedding""" + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10 + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + + self.img_size = img_size + self.patch_size = patch_size + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride + ) # with overlapped patches + # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) + # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w + self.patch_hw = (h, w) + self.num_patches = h * w + + def get_output_shape(self, img_size): + # todo: don't be lazy.. + return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + # x = self.proj(x).flatten(2).transpose(1, 2) + x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12 + x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212 + x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768 + return x + + +class PatchEmbed3D_new(nn.Module): + """Flexible Image to Patch Embedding""" + + def __init__( + self, + video_size=(16, 224, 224), + patch_size=(2, 16, 16), + in_chans=3, + embed_dim=768, + stride=(2, 16, 16), + ): + super().__init__() + + self.video_size = video_size + self.patch_size = patch_size + self.in_chans = in_chans + + self.proj = nn.Conv3d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride + ) + _, _, t, h, w = self.get_output_shape(video_size) # n, emb_dim, h, w + self.patch_thw = (t, h, w) + self.num_patches = t * h * w + + def get_output_shape(self, video_size): + # todo: don't be lazy.. + return self.proj( + torch.randn(1, self.in_chans, video_size[0], video_size[1], video_size[2]) + ).shape + + def forward(self, x): + B, C, T, H, W = x.shape + x = self.proj(x) # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14 + x = x.flatten(2) # 32, 768, 1568 + x = x.transpose(1, 2) # 32, 768, 1568 -> 32, 1568, 768 + return x + + +if __name__ == "__main__": + # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16)) + # input = torch.rand(8,1,1024,128) + # output = patch_emb(input) + # print(output.shape) # (8,512,64) + + patch_emb = PatchEmbed3D_new( + video_size=(6, 224, 224), + patch_size=(2, 16, 16), + in_chans=3, + embed_dim=768, + stride=(2, 16, 16), + ) + input = torch.rand(8, 3, 6, 224, 224) + output = patch_emb(input) + print(output.shape) # (8,64) diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py b/audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py new file mode 100755 index 0000000000000000000000000000000000000000..2d9177ed98dffcf35264f38aff94e7f00fb50abf --- /dev/null +++ b/audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np + +import torch + + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size[0], dtype=np.float32) + grid_w = np.arange(grid_size[1], dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + # omega = np.arange(embed_dim // 2, dtype=np.float) + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size, orig_size, new_size, new_size) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + # orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + # new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size[0], orig_size[1], new_size[0], new_size[1]) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size[0], orig_size[1], embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size[0], new_size[1]), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + model.pos_embed.shape[-2] - num_patches + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size[0], orig_size[1], new_size[0], new_size[1]) + ) + # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1) + pos_tokens = pos_embed_checkpoint[:, 1:, :] # remove + pos_tokens = pos_tokens.reshape( + -1, orig_size[0], orig_size[1], embedding_size + ) # .permute(0, 3, 1, 2) + # pos_tokens = torch.nn.functional.interpolate( + # pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False) + + # pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + pos_tokens = pos_tokens[:, :, : new_size[1], :] # assume only time diff + pos_tokens = pos_tokens.flatten(1, 2) + new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +def interpolate_patch_embed_audio( + model, + checkpoint_model, + orig_channel, + new_channel=1, + kernel_size=(16, 16), + stride=(16, 16), + padding=(0, 0), +): + if orig_channel != new_channel: + if "patch_embed.proj.weight" in checkpoint_model: + # aggregate 3 channels in rgb ckpt to 1 channel for audio + new_proj_weight = torch.nn.Parameter( + torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze( + 1 + ) + ) + checkpoint_model["patch_embed.proj.weight"] = new_proj_weight diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/stat.py b/audioldm2/latent_diffusion/modules/audiomae/util/stat.py new file mode 100755 index 0000000000000000000000000000000000000000..3f8137249503f6eaa25c3170fe5ef6b87f187347 --- /dev/null +++ b/audioldm2/latent_diffusion/modules/audiomae/util/stat.py @@ -0,0 +1,76 @@ +import numpy as np +from scipy import stats +from sklearn import metrics +import torch + + +def d_prime(auc): + standard_normal = stats.norm() + d_prime = standard_normal.ppf(auc) * np.sqrt(2.0) + return d_prime + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +def calculate_stats(output, target): + """Calculate statistics including mAP, AUC, etc. + + Args: + output: 2d array, (samples_num, classes_num) + target: 2d array, (samples_num, classes_num) + + Returns: + stats: list of statistic of each class. + """ + + classes_num = target.shape[-1] + stats = [] + + # Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet + acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1)) + + # Class-wise statistics + for k in range(classes_num): + # Average precision + avg_precision = metrics.average_precision_score( + target[:, k], output[:, k], average=None + ) + + # AUC + # auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None) + + # Precisions, recalls + (precisions, recalls, thresholds) = metrics.precision_recall_curve( + target[:, k], output[:, k] + ) + + # FPR, TPR + (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k]) + + save_every_steps = 1000 # Sample statistics to reduce size + dict = { + "precisions": precisions[0::save_every_steps], + "recalls": recalls[0::save_every_steps], + "AP": avg_precision, + "fpr": fpr[0::save_every_steps], + "fnr": 1.0 - tpr[0::save_every_steps], + # 'auc': auc, + # note acc is not class-wise, this is just to keep consistent with other metrics + "acc": acc, + } + stats.append(dict) + + return stats diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/__init__.py b/audioldm2/latent_diffusion/modules/diffusionmodules/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/model.py b/audioldm2/latent_diffusion/modules/diffusionmodules/model.py new file mode 100755 index 0000000000000000000000000000000000000000..851f8dd28e80046c5e3c9d95bd37726024f1367c --- /dev/null +++ b/audioldm2/latent_diffusion/modules/diffusionmodules/model.py @@ -0,0 +1,1069 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from audioldm2.latent_diffusion.util import instantiate_from_config +from audioldm2.latent_diffusion.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class UpsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=1, padding=2 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class DownsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2)) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_ + ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + downsample_time_stride4_levels=[], + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level in self.downsample_time_stride4_levels: + down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv) + else: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + downsample_time_stride4_levels=[], + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # compute in_ch_mult, block_in and curr_res at lowest res + (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + # print( + # "Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape) + # ) + # ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level - 1 in self.downsample_time_stride4_levels: + up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv) + else: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x).contiguous() + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x + + +class FirstStagePostProcessor(nn.Module): + def __init__( + self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0.0, + pretrained_config=None, + ): + super().__init__() + if pretrained_config is None: + assert ( + pretrained_model is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert ( + pretrained_config is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d( + in_channels, n_channels, kernel_size=3, stride=1, padding=1 + ) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append( + ResnetBlock( + in_channels=ch_in, out_channels=m * n_channels, dropout=dropout + ) + ) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode_with_pretrained(self, x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self, x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z, "b c h w -> b (h w) c") + return z diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/openaimodel.py b/audioldm2/latent_diffusion/modules/diffusionmodules/openaimodel.py new file mode 100755 index 0000000000000000000000000000000000000000..e006e5a332c3cde5f4e221f003b270d86b34e933 --- /dev/null +++ b/audioldm2/latent_diffusion/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1103 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from audioldm2.latent_diffusion.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from audioldm2.latent_diffusion.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1).contiguous() # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context_list=None, mask_list=None): + # The first spatial transformer block does not have context + spatial_transformer_id = 0 + context_list = [None] + context_list + mask_list = [None] + mask_list + + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + if spatial_transformer_id >= len(context_list): + context, mask = None, None + else: + context, mask = ( + context_list[spatial_transformer_id], + mask_list[spatial_transformer_id], + ) + + x = layer(x, context, mask=mask) + spatial_transformer_id += 1 + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1).contiguous() + qkv = self.qkv(self.norm(x)).contiguous() + h = self.attention(qkv).contiguous() + h = self.proj_out(h).contiguous() + return (x + h).reshape(b, c, *spatial).contiguous() + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = ( + qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1) + ) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum( + "bts,bcs->bct", + weight, + v.reshape(bs * self.n_heads, ch, length).contiguous(), + ) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + extra_sa_layer=True, + num_classes=None, + extra_film_condition_dim=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=True, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.extra_film_condition_dim = extra_film_condition_dim + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + # assert not ( + # self.num_classes is not None and self.extra_film_condition_dim is not None + # ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim." + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.use_extra_film_by_concat = self.extra_film_condition_dim is not None + + if self.extra_film_condition_dim is not None: + self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim) + print( + "+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " + % self.extra_film_condition_dim + ) + + if context_dim is not None and not use_spatial_transformer: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + + if context_dim is not None and not isinstance(context_dim, list): + context_dim = [context_dim] + elif context_dim is None: + context_dim = [None] # At least use one spatial transformer + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if extra_sa_layer: + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=None, + ) + ) + for context_dim_id in range(len(context_dim)): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + middle_layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + if extra_sa_layer: + middle_layers.append( + SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=None + ) + ) + for context_dim_id in range(len(context_dim)): + middle_layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + middle_layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + self.middle_block = TimestepEmbedSequential(*middle_layers) + + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if extra_sa_layer: + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=None, + ) + ) + for context_dim_id in range(len(context_dim)): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + self.shape_reported = False + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward( + self, + x, + timesteps=None, + y=None, + context_list=None, + context_attn_mask_list=None, + **kwargs, + ): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional + :return: an [N x C x ...] Tensor of outputs. + """ + if not self.shape_reported: + # print("The shape of UNet input is", x.size()) + self.shape_reported = True + + assert (y is not None) == ( + self.num_classes is not None or self.extra_film_condition_dim is not None + ), "must specify y if and only if the model is class-conditional or film embedding conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + # if self.num_classes is not None: + # assert y.shape == (x.shape[0],) + # emb = emb + self.label_emb(y) + + if self.use_extra_film_by_concat: + emb = th.cat([emb, self.film_emb(y)], dim=-1) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context_list, context_attn_mask_list) + hs.append(h) + h = self.middle_block(h, emb, context_list, context_attn_mask_list) + for module in self.output_blocks: + concate_tensor = hs.pop() + h = th.cat([h, concate_tensor], dim=1) + h = module(h, emb, context_list, context_attn_mask_list) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/util.py b/audioldm2/latent_diffusion/modules/diffusionmodules/util.py new file mode 100755 index 0000000000000000000000000000000000000000..0d486f919a7ccc0586bc40225dac0ffb33aed01c --- /dev/null +++ b/audioldm2/latent_diffusion/modules/diffusionmodules/util.py @@ -0,0 +1,294 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from audioldm2.latent_diffusion.util import instantiate_from_config + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == "sqrt": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t).contiguous() + return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous() + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/audioldm2/latent_diffusion/modules/ema.py b/audioldm2/latent_diffusion/modules/ema.py new file mode 100755 index 0000000000000000000000000000000000000000..880ca3d205d9b4d7450e146930a93f2e63c58b70 --- /dev/null +++ b/audioldm2/latent_diffusion/modules/ema.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/audioldm2/latent_diffusion/modules/encoders/__init__.py b/audioldm2/latent_diffusion/modules/encoders/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm2/latent_diffusion/modules/encoders/modules.py b/audioldm2/latent_diffusion/modules/encoders/modules.py new file mode 100755 index 0000000000000000000000000000000000000000..7a72339840c0c3b667e907ea07ee7cb755eb66fd --- /dev/null +++ b/audioldm2/latent_diffusion/modules/encoders/modules.py @@ -0,0 +1,736 @@ +import torch +import logging +import torch.nn as nn +from audioldm2.clap.open_clip import create_model +from audioldm2.clap.training.data import get_audio_features +import torchaudio +from transformers import RobertaTokenizer, AutoTokenizer, T5EncoderModel +import torch.nn.functional as F +from audioldm2.latent_diffusion.modules.audiomae.AudioMAE import Vanilla_AudioMAE +from audioldm2.latent_diffusion.modules.phoneme_encoder.encoder import TextEncoder + +from transformers import AutoTokenizer, T5Config + +from audioldm2.audiomae_gen.sequence_input import Sequence2AudioMAE +import numpy as np + +""" +The model forward function can return three types of data: +1. tensor: used directly as conditioning signal +2. dict: where there is a main key as condition, there are also other key that you can use to pass loss function and itermediate result. etc. +3. list: the length is 2, in which the first element is tensor, the second element is attntion mask. + +The output shape for the cross attention condition should be: +x,x_mask = [bs, seq_len, emb_dim], [bs, seq_len] + +All the returned data, in which will be used as diffusion input, will need to be in float type +""" + + +class PhonemeEncoder(nn.Module): + def __init__(self, vocabs_size=41, pad_length=250, pad_token_id=None): + super().__init__() + """ + encoder = PhonemeEncoder(40) + data = torch.randint(0, 39, (2, 250)) + output = encoder(data) + import ipdb;ipdb.set_trace() + """ + assert pad_token_id is not None + + self.device = None + self.PAD_LENGTH = int(pad_length) + self.pad_token_id = pad_token_id + self.pad_token_sequence = torch.tensor([self.pad_token_id] * self.PAD_LENGTH) + + self.text_encoder = TextEncoder( + n_vocab=vocabs_size, + out_channels=192, + hidden_channels=192, + filter_channels=768, + n_heads=2, + n_layers=6, + kernel_size=3, + p_dropout=0.1, + ) + + self.learnable_positional_embedding = torch.nn.Parameter( + torch.zeros((1, 192, self.PAD_LENGTH)) + ) # [batchsize, seqlen, padlen] + self.learnable_positional_embedding.requires_grad = True + + # Required + def get_unconditional_condition(self, batchsize): + unconditional_tokens = self.pad_token_sequence.expand( + batchsize, self.PAD_LENGTH + ) + return self(unconditional_tokens) # Need to return float type + + # def get_unconditional_condition(self, batchsize): + + # hidden_state = torch.zeros((batchsize, self.PAD_LENGTH, 192)).to(self.device) + # attention_mask = torch.ones((batchsize, self.PAD_LENGTH)).to(self.device) + # return [hidden_state, attention_mask] # Need to return float type + + def _get_src_mask(self, phoneme): + src_mask = phoneme != self.pad_token_id + return src_mask + + def _get_src_length(self, phoneme): + src_mask = self._get_src_mask(phoneme) + length = torch.sum(src_mask, dim=-1) + return length + + # def make_empty_condition_unconditional(self, src_length, text_emb, attention_mask): + # # src_length: [bs] + # # text_emb: [bs, 192, pad_length] + # # attention_mask: [bs, pad_length] + # mask = src_length[..., None, None] > 1 + # text_emb = text_emb * mask + + # attention_mask[src_length < 1] = attention_mask[src_length < 1] * 0.0 + 1.0 + # return text_emb, attention_mask + + def forward(self, phoneme_idx): + if self.device is None: + self.device = self.learnable_positional_embedding.device + self.pad_token_sequence = self.pad_token_sequence.to(self.device) + + src_length = self._get_src_length(phoneme_idx) + text_emb, m, logs, text_emb_mask = self.text_encoder(phoneme_idx, src_length) + text_emb = text_emb + self.learnable_positional_embedding + + # text_emb, text_emb_mask = self.make_empty_condition_unconditional(src_length, text_emb, text_emb_mask) + + return [ + text_emb.permute(0, 2, 1), + text_emb_mask.squeeze(1), + ] # [2, 250, 192], [2, 250] + + +class FlanT5HiddenState(nn.Module): + """ + llama = FlanT5HiddenState() + data = ["","this is not an empty sentence"] + encoder_hidden_states = llama(data) + import ipdb;ipdb.set_trace() + """ + + def __init__( + self, text_encoder_name="google/flan-t5-large", freeze_text_encoder=True + ): + super().__init__() + self.freeze_text_encoder = freeze_text_encoder + self.tokenizer = AutoTokenizer.from_pretrained(text_encoder_name) + self.model = T5EncoderModel(T5Config.from_pretrained(text_encoder_name)) + if freeze_text_encoder: + self.model.eval() + for p in self.model.parameters(): + p.requires_grad = False + else: + print("=> The text encoder is learnable") + + self.empty_hidden_state_cfg = None + self.device = None + + # Required + def get_unconditional_condition(self, batchsize): + param = next(self.model.parameters()) + if self.freeze_text_encoder: + assert param.requires_grad == False + + # device = param.device + if self.empty_hidden_state_cfg is None: + self.empty_hidden_state_cfg, _ = self([""]) + + hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float() + attention_mask = ( + torch.ones((batchsize, hidden_state.size(1))) + .to(hidden_state.device) + .float() + ) + return [hidden_state, attention_mask] # Need to return float type + + def forward(self, batch): + param = next(self.model.parameters()) + if self.freeze_text_encoder: + assert param.requires_grad == False + + if self.device is None: + self.device = param.device + + # print("Manually change text") + # for i in range(len(batch)): + # batch[i] = "dog barking" + try: + return self.encode_text(batch) + except Exception as e: + print(e, batch) + logging.exception("An error occurred: %s", str(e)) + + def encode_text(self, prompt): + device = self.model.device + batch = self.tokenizer( + prompt, + max_length=128, # self.tokenizer.model_max_length + padding=True, + truncation=True, + return_tensors="pt", + ) + input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to( + device + ) + # Get text encoding + if self.freeze_text_encoder: + with torch.no_grad(): + encoder_hidden_states = self.model( + input_ids=input_ids, attention_mask=attention_mask + )[0] + else: + encoder_hidden_states = self.model( + input_ids=input_ids, attention_mask=attention_mask + )[0] + return [ + encoder_hidden_states.detach(), + attention_mask.float(), + ] # Attention mask == 1 means usable token + + +class SequenceGenAudioMAECond(Sequence2AudioMAE): + def __init__( + self, + cond_stage_config, + base_learning_rate, + sequence_gen_length, + sequence_input_key, + sequence_input_embed_dim, + batchsize, + always_output_audiomae_gt=False, + pretrained_path=None, + force_reload_pretrain_avoid_overwrite=False, + learnable=True, + use_warmup=True, + device=None, + use_gt_mae_output=None, # False: does not use AudioMAE GT, True: Use AudioMAE GT + use_gt_mae_prob=None, + ): # The prob of using AudioMAE GT + if use_warmup: + use_warmup = False + + super().__init__( + base_learning_rate=base_learning_rate, + cond_stage_config=cond_stage_config, + sequence_gen_length=sequence_gen_length, + sequence_input_key=sequence_input_key, + use_warmup=use_warmup, + sequence_input_embed_dim=sequence_input_embed_dim, + batchsize=batchsize, + ) + + assert use_gt_mae_output is not None and use_gt_mae_prob is not None + self.always_output_audiomae_gt = always_output_audiomae_gt + self.force_reload_pretrain_avoid_overwrite = ( + force_reload_pretrain_avoid_overwrite + ) + self.pretrained_path = pretrained_path + self.device = device + if self.force_reload_pretrain_avoid_overwrite: + self.is_reload = False + else: + self.is_reload = True + + self.load_pretrain_model() + + self.use_gt_mae_output = use_gt_mae_output + self.use_gt_mae_prob = use_gt_mae_prob + self.learnable = learnable + + if not learnable: + # Only optimize the GPT2 model + for p in self.model.parameters(): + p.requires_grad = False + self.eval() + + def load_pretrain_model(self): + if self.pretrained_path is not None: + print("Reload SequenceGenAudioMAECond from %s" % self.pretrained_path) + state_dict = torch.load(self.pretrained_path)["state_dict"] + self.load_state_dict(state_dict) + + # Required + def get_unconditional_condition(self, batchsize): + return_dict = self.cfg_uncond(batchsize) + return_dict["crossattn_audiomae_generated"] = [ + return_dict["crossattn_audiomae_pooled"][0], + torch.ones_like(return_dict["crossattn_audiomae_pooled"][1]).float(), + ] + return return_dict + + def forward(self, batch): + # The conditional module can return both tensor or dictionaries + # The returned tensor will be corresponding to the cond_stage_key + # The returned dict will have keys that correspond to the cond_stage_key + ret_dict = {} + + if self.force_reload_pretrain_avoid_overwrite and not self.is_reload: + self.load_pretrain_model() + self.is_reload = True + + # if(self.always_output_audiomae_gt or (self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob)): + # cond_dict = self.get_input(batch) + # ret_dict["crossattn_audiomae_generated"] = [cond_dict["crossattn_audiomae_pooled"][0], torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float()] # Input sequence and mask + # else: + input_embeds, cond_dict = self.generate(batch) + input_embeds_mask = ( + torch.ones((input_embeds.size(0), input_embeds.size(1))) + .to(input_embeds.device) + .float() + ) + ret_dict["crossattn_audiomae_generated"] = [ + input_embeds, + input_embeds_mask, + ] # Input sequence and mask + + # If the following two keys are not in cond_stage_key, then they will not be used as condition + for key in cond_dict.keys(): + ret_dict[key] = cond_dict[key] + + return ret_dict + + +class AudioMAEConditionCTPoolRandTFSeparated(nn.Module): + """ + audiomae = AudioMAEConditionCTPool2x2() + data = torch.randn((4, 1024, 128)) + output = audiomae(data) + import ipdb;ipdb.set_trace() + exit(0) + """ + + def __init__( + self, + time_pooling_factors=[1, 2, 4, 8], + freq_pooling_factors=[1, 2, 4, 8], + eval_time_pooling=None, + eval_freq_pooling=None, + mask_ratio=0.0, + regularization=False, + no_audiomae_mask=True, + no_audiomae_average=False, + ): + super().__init__() + self.device = None + self.time_pooling_factors = time_pooling_factors + self.freq_pooling_factors = freq_pooling_factors + self.no_audiomae_mask = no_audiomae_mask + self.no_audiomae_average = no_audiomae_average + + self.eval_freq_pooling = eval_freq_pooling + self.eval_time_pooling = eval_time_pooling + self.mask_ratio = mask_ratio + self.use_reg = regularization + + self.audiomae = Vanilla_AudioMAE() + self.audiomae.eval() + for p in self.audiomae.parameters(): + p.requires_grad = False + + # Required + def get_unconditional_condition(self, batchsize): + param = next(self.audiomae.parameters()) + assert param.requires_grad == False + device = param.device + # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors) + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))] + # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] + token_num = int(512 / (time_pool * freq_pool)) + return [ + torch.zeros((batchsize, token_num, 768)).to(device).float(), + torch.ones((batchsize, token_num)).to(device).float(), + ] + + def pool(self, representation, time_pool=None, freq_pool=None): + assert representation.size(-1) == 768 + representation = representation[:, 1:, :].transpose(1, 2) + bs, embedding_dim, token_num = representation.size() + representation = representation.reshape(bs, embedding_dim, 64, 8) + + if self.training: + if time_pool is None and freq_pool is None: + time_pool = min( + 64, + self.time_pooling_factors[ + np.random.choice(list(range(len(self.time_pooling_factors)))) + ], + ) + freq_pool = min( + 8, + self.freq_pooling_factors[ + np.random.choice(list(range(len(self.freq_pooling_factors)))) + ], + ) + # freq_pool = min(8, time_pool) # TODO here I make some modification. + else: + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + + self.avgpooling = nn.AvgPool2d( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + self.maxpooling = nn.MaxPool2d( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + + pooled = ( + self.avgpooling(representation) + self.maxpooling(representation) + ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num] + pooled = pooled.flatten(2).transpose(1, 2) + return pooled # [bs, token_num, embedding_dim] + + def regularization(self, x): + assert x.size(-1) == 768 + x = F.normalize(x, p=2, dim=-1) + return x + + # Required + def forward(self, batch, time_pool=None, freq_pool=None): + assert batch.size(-2) == 1024 and batch.size(-1) == 128 + + if self.device is None: + self.device = batch.device + + batch = batch.unsqueeze(1) + with torch.no_grad(): + representation = self.audiomae( + batch, + mask_ratio=self.mask_ratio, + no_mask=self.no_audiomae_mask, + no_average=self.no_audiomae_average, + ) + representation = self.pool(representation, time_pool, freq_pool) + if self.use_reg: + representation = self.regularization(representation) + return [ + representation, + torch.ones((representation.size(0), representation.size(1))) + .to(representation.device) + .float(), + ] + + +class AudioMAEConditionCTPoolRand(nn.Module): + """ + audiomae = AudioMAEConditionCTPool2x2() + data = torch.randn((4, 1024, 128)) + output = audiomae(data) + import ipdb;ipdb.set_trace() + exit(0) + """ + + def __init__( + self, + time_pooling_factors=[1, 2, 4, 8], + freq_pooling_factors=[1, 2, 4, 8], + eval_time_pooling=None, + eval_freq_pooling=None, + mask_ratio=0.0, + regularization=False, + no_audiomae_mask=True, + no_audiomae_average=False, + ): + super().__init__() + self.device = None + self.time_pooling_factors = time_pooling_factors + self.freq_pooling_factors = freq_pooling_factors + self.no_audiomae_mask = no_audiomae_mask + self.no_audiomae_average = no_audiomae_average + + self.eval_freq_pooling = eval_freq_pooling + self.eval_time_pooling = eval_time_pooling + self.mask_ratio = mask_ratio + self.use_reg = regularization + + self.audiomae = Vanilla_AudioMAE() + self.audiomae.eval() + for p in self.audiomae.parameters(): + p.requires_grad = False + + # Required + def get_unconditional_condition(self, batchsize): + param = next(self.audiomae.parameters()) + assert param.requires_grad == False + device = param.device + # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors) + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))] + # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] + token_num = int(512 / (time_pool * freq_pool)) + return [ + torch.zeros((batchsize, token_num, 768)).to(device).float(), + torch.ones((batchsize, token_num)).to(device).float(), + ] + + def pool(self, representation, time_pool=None, freq_pool=None): + assert representation.size(-1) == 768 + representation = representation[:, 1:, :].transpose(1, 2) + bs, embedding_dim, token_num = representation.size() + representation = representation.reshape(bs, embedding_dim, 64, 8) + + if self.training: + if time_pool is None and freq_pool is None: + time_pool = min( + 64, + self.time_pooling_factors[ + np.random.choice(list(range(len(self.time_pooling_factors)))) + ], + ) + # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] + freq_pool = min(8, time_pool) # TODO here I make some modification. + else: + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + + self.avgpooling = nn.AvgPool2d( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + self.maxpooling = nn.MaxPool2d( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + + pooled = ( + self.avgpooling(representation) + self.maxpooling(representation) + ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num] + pooled = pooled.flatten(2).transpose(1, 2) + return pooled # [bs, token_num, embedding_dim] + + def regularization(self, x): + assert x.size(-1) == 768 + x = F.normalize(x, p=2, dim=-1) + return x + + # Required + def forward(self, batch, time_pool=None, freq_pool=None): + assert batch.size(-2) == 1024 and batch.size(-1) == 128 + + if self.device is None: + self.device = batch.device + + batch = batch.unsqueeze(1) + with torch.no_grad(): + representation = self.audiomae( + batch, + mask_ratio=self.mask_ratio, + no_mask=self.no_audiomae_mask, + no_average=self.no_audiomae_average, + ) + representation = self.pool(representation, time_pool, freq_pool) + if self.use_reg: + representation = self.regularization(representation) + return [ + representation, + torch.ones((representation.size(0), representation.size(1))) + .to(representation.device) + .float(), + ] + + +class CLAPAudioEmbeddingClassifierFreev2(nn.Module): + def __init__( + self, + pretrained_path="", + sampling_rate=16000, + embed_mode="audio", + amodel="HTSAT-base", + unconditional_prob=0.1, + random_mute=False, + max_random_mute_portion=0.5, + training_mode=True, + ): + super().__init__() + self.device = "cpu" + self.precision = "fp32" + self.amodel = amodel # or 'PANN-14' + self.tmodel = "roberta" # the best text encoder in our training + self.enable_fusion = False # False if you do not want to use the fusion model + self.fusion_type = "aff_2d" + self.pretrained = pretrained_path + self.embed_mode = embed_mode + self.embed_mode_orig = embed_mode + self.sampling_rate = sampling_rate + self.unconditional_prob = unconditional_prob + self.random_mute = random_mute + self.tokenize = RobertaTokenizer.from_pretrained("roberta-base") + self.max_random_mute_portion = max_random_mute_portion + self.training_mode = training_mode + self.model, self.model_cfg = create_model( + self.amodel, + self.tmodel, + self.pretrained, + precision=self.precision, + device=self.device, + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + audio_cfg = self.model_cfg["audio_cfg"] + self.mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=audio_cfg["sample_rate"], + n_fft=audio_cfg["window_size"], + win_length=audio_cfg["window_size"], + hop_length=audio_cfg["hop_size"], + center=True, + pad_mode="reflect", + power=2.0, + norm=None, + onesided=True, + n_mels=64, + f_min=audio_cfg["fmin"], + f_max=audio_cfg["fmax"], + ) + for p in self.model.parameters(): + p.requires_grad = False + self.unconditional_token = None + self.model.eval() + + def get_unconditional_condition(self, batchsize): + self.unconditional_token = self.model.get_text_embedding( + self.tokenizer(["", ""]) + )[0:1] + return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0) + + def batch_to_list(self, batch): + ret = [] + for i in range(batch.size(0)): + ret.append(batch[i]) + return ret + + def make_decision(self, probability): + if float(torch.rand(1)) < probability: + return True + else: + return False + + def random_uniform(self, start, end): + val = torch.rand(1).item() + return start + (end - start) * val + + def _random_mute(self, waveform): + # waveform: [bs, t-steps] + t_steps = waveform.size(-1) + for i in range(waveform.size(0)): + mute_size = int( + self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion)) + ) + mute_start = int(self.random_uniform(0, t_steps - mute_size)) + waveform[i, mute_start : mute_start + mute_size] = 0 + return waveform + + def cos_similarity(self, waveform, text): + # waveform: [bs, t_steps] + original_embed_mode = self.embed_mode + with torch.no_grad(): + self.embed_mode = "audio" + audio_emb = self(waveform.cuda()) + self.embed_mode = "text" + text_emb = self(text) + similarity = F.cosine_similarity(audio_emb, text_emb, dim=2) + self.embed_mode = original_embed_mode + return similarity.squeeze() + + def build_unconditional_emb(self): + self.unconditional_token = self.model.get_text_embedding( + self.tokenizer(["", ""]) + )[0:1] + + def forward(self, batch): + # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0 + # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0 + if self.model.training == True and not self.training_mode: + print( + "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters." + ) + self.model, self.model_cfg = create_model( + self.amodel, + self.tmodel, + self.pretrained, + precision=self.precision, + device="cuda", + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + for p in self.model.parameters(): + p.requires_grad = False + self.model.eval() + + if self.unconditional_token is None: + self.build_unconditional_emb() + + # if(self.training_mode): + # assert self.model.training == True + # else: + # assert self.model.training == False + + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + if self.embed_mode == "audio": + if not self.training: + print("INFO: clap model calculate the audio embedding as condition") + with torch.no_grad(): + # assert ( + # self.sampling_rate == 16000 + # ), "We only support 16000 sampling rate" + + # if self.random_mute: + # batch = self._random_mute(batch) + # batch: [bs, 1, t-samples] + if self.sampling_rate != 48000: + batch = torchaudio.functional.resample( + batch, orig_freq=self.sampling_rate, new_freq=48000 + ) + + audio_data = batch.squeeze(1) + mel = self.mel_transform(audio_data) + audio_dict = get_audio_features( + audio_data, + mel, + 480000, + data_truncating="fusion", + data_filling="repeatpad", + audio_cfg=self.model_cfg["audio_cfg"], + ) + # [bs, 512] + embed = self.model.get_audio_embedding(audio_dict) + elif self.embed_mode == "text": + with torch.no_grad(): + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + text_data = self.tokenizer(batch) + + if isinstance(batch, str) or ( + isinstance(batch, list) and len(batch) == 1 + ): + for key in text_data.keys(): + text_data[key] = text_data[key].unsqueeze(0) + + embed = self.model.get_text_embedding(text_data) + + embed = embed.unsqueeze(1) + for i in range(embed.size(0)): + if self.make_decision(self.unconditional_prob): + embed[i] = self.unconditional_token + # embed = torch.randn((batch.size(0), 1, 512)).type_as(batch) + return embed.detach() + + def tokenizer(self, text): + result = self.tokenize( + text, + padding="max_length", + truncation=True, + max_length=512, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/__init__.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/attentions.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/attentions.py new file mode 100755 index 0000000000000000000000000000000000000000..3553a688d41b07a45a7ced25f740a55dbc0b6d94 --- /dev/null +++ b/audioldm2/latent_diffusion/modules/phoneme_encoder/attentions.py @@ -0,0 +1,430 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +import audioldm2.latent_diffusion.modules.phoneme_encoder.commons as commons + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention( + hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + device=x.device, dtype=x.dtype + ) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert ( + t_s == t_t + ), "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys( + query / math.sqrt(self.k_channels), key_relative_embeddings + ) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert ( + t_s == t_t + ), "Local attention is only available for self-attention." + block_mask = ( + torch.ones_like(scores) + .triu(-self.block_length) + .tril(self.block_length) + ) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s + ) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings + ) + output = ( + output.transpose(2, 3).contiguous().view(b, d, t_t) + ) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[ + :, slice_start_position:slice_end_position + ] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad( + x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) + ) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad( + x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) + ) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py new file mode 100755 index 0000000000000000000000000000000000000000..9515724c12ab2f856b9a2ec14e38cc63df9b85d6 --- /dev/null +++ b/audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py @@ -0,0 +1,161 @@ +import math +import torch +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += ( + 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..b39bf583b5ea88a4771181e491c8deb92b2d7559 --- /dev/null +++ b/audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py @@ -0,0 +1,50 @@ +import math +import torch +from torch import nn + +import audioldm2.latent_diffusion.modules.phoneme_encoder.commons as commons +import audioldm2.latent_diffusion.modules.phoneme_encoder.attentions as attentions + + +class TextEncoder(nn.Module): + def __init__( + self, + n_vocab, + out_channels=192, + hidden_channels=192, + filter_channels=768, + n_heads=2, + n_layers=6, + kernel_size=3, + p_dropout=0.1, + ): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + self.encoder = attentions.Encoder( + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask diff --git a/audioldm2/latent_diffusion/util.py b/audioldm2/latent_diffusion/util.py new file mode 100755 index 0000000000000000000000000000000000000000..3dd301b1c0a39a5b905aa23f4b98d224df7d87d9 --- /dev/null +++ b/audioldm2/latent_diffusion/util.py @@ -0,0 +1,217 @@ +import importlib + +import torch +import numpy as np +from collections import abc + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join( + xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) + ) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def float32_to_int16(x): + x = np.clip(x, a_min=-1.0, a_max=1.0) + return (x * 32767.0).astype(np.int16) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, + data, + n_proc, + target_data_type="ndarray", + cpu_intensive=True, + use_worker_id=False, +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i : i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == "ndarray": + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == "list": + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/audioldm2/latent_encoder/__init__.py b/audioldm2/latent_encoder/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm2/latent_encoder/autoencoder.py b/audioldm2/latent_encoder/autoencoder.py new file mode 100755 index 0000000000000000000000000000000000000000..f07075bb76a34edd8568797961752e4957129f92 --- /dev/null +++ b/audioldm2/latent_encoder/autoencoder.py @@ -0,0 +1,326 @@ +import torch +import os + +import torch.nn.functional as F +import numpy as np +from audioldm2.latent_diffusion.modules.ema import * + +from audioldm2.latent_diffusion.modules.diffusionmodules.model import Encoder, Decoder +from audioldm2.latent_diffusion.modules.distributions.distributions import ( + DiagonalGaussianDistribution, +) +import soundfile as sf + +from audioldm2.utilities.model import get_vocoder +from audioldm2.utilities.tools import synth_one_sample + + +class AutoencoderKL(nn.Module): + def __init__( + self, + ddconfig=None, + lossconfig=None, + batchsize=None, + embed_dim=None, + time_shuffle=1, + subband=1, + sampling_rate=16000, + ckpt_path=None, + reload_from_ckpt=None, + ignore_keys=[], + image_key="fbank", + colorize_nlabels=None, + monitor=None, + base_learning_rate=1e-5, + ): + super().__init__() + self.automatic_optimization = False + assert ( + "mel_bins" in ddconfig.keys() + ), "mel_bins is not specified in the Autoencoder config" + num_mel = ddconfig["mel_bins"] + self.image_key = image_key + self.sampling_rate = sampling_rate + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + + self.loss = None + self.subband = int(subband) + + if self.subband > 1: + print("Use subband decomposition %s" % self.subband) + + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + + if self.image_key == "fbank": + self.vocoder = get_vocoder(None, "cpu", num_mel) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.learning_rate = float(base_learning_rate) + # print("Initial learning rate %s" % self.learning_rate) + + self.time_shuffle = time_shuffle + self.reload_from_ckpt = reload_from_ckpt + self.reloaded = False + self.mean, self.std = None, None + + self.feature_cache = None + self.flag_first_run = True + self.train_step = 0 + + self.logger_save_dir = None + self.logger_exp_name = None + + def get_log_dir(self): + if self.logger_save_dir is None and self.logger_exp_name is None: + return os.path.join(self.logger.save_dir, self.logger._project) + else: + return os.path.join(self.logger_save_dir, self.logger_exp_name) + + def set_log_dir(self, save_dir, exp_name): + self.logger_save_dir = save_dir + self.logger_exp_name = exp_name + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + # x = self.time_shuffle_operation(x) + # x = self.freq_split_subband(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + # bs, ch, shuffled_timesteps, fbins = dec.size() + # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins) + # dec = self.freq_merge_subband(dec) + return dec + + def decode_to_waveform(self, dec): + from audioldm2.utilities.model import vocoder_infer + + if self.image_key == "fbank": + dec = dec.squeeze(1).permute(0, 2, 1) + wav_reconstruction = vocoder_infer(dec, self.vocoder) + elif self.image_key == "stft": + dec = dec.squeeze(1).permute(0, 2, 1) + wav_reconstruction = self.wave_decoder(dec) + return wav_reconstruction + + def visualize_latent(self, input): + import matplotlib.pyplot as plt + + # for i in range(10): + # zero_input = torch.zeros_like(input) - 11.59 + # zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59 + + # posterior = self.encode(zero_input) + # latent = posterior.sample() + # avg_latent = torch.mean(latent, dim=1)[0] + # plt.imshow(avg_latent.cpu().detach().numpy().T) + # plt.savefig("%s.png" % i) + # plt.close() + + np.save("input.npy", input.cpu().detach().numpy()) + # zero_input = torch.zeros_like(input) - 11.59 + time_input = input.clone() + time_input[:, :, :, :32] *= 0 + time_input[:, :, :, :32] -= 11.59 + + np.save("time_input.npy", time_input.cpu().detach().numpy()) + + posterior = self.encode(time_input) + latent = posterior.sample() + np.save("time_latent.npy", latent.cpu().detach().numpy()) + avg_latent = torch.mean(latent, dim=1) + for i in range(avg_latent.size(0)): + plt.imshow(avg_latent[i].cpu().detach().numpy().T) + plt.savefig("freq_%s.png" % i) + plt.close() + + freq_input = input.clone() + freq_input[:, :, :512, :] *= 0 + freq_input[:, :, :512, :] -= 11.59 + + np.save("freq_input.npy", freq_input.cpu().detach().numpy()) + + posterior = self.encode(freq_input) + latent = posterior.sample() + np.save("freq_latent.npy", latent.cpu().detach().numpy()) + avg_latent = torch.mean(latent, dim=1) + for i in range(avg_latent.size(0)): + plt.imshow(avg_latent[i].cpu().detach().numpy().T) + plt.savefig("time_%s.png" % i) + plt.close() + + def get_input(self, batch): + fname, text, label_indices, waveform, stft, fbank = ( + batch["fname"], + batch["text"], + batch["label_vector"], + batch["waveform"], + batch["stft"], + batch["log_mel_spec"], + ) + # if(self.time_shuffle != 1): + # if(fbank.size(1) % self.time_shuffle != 0): + # pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle) + # fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len)) + + ret = {} + + ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = ( + fbank.unsqueeze(1), + stft.unsqueeze(1), + fname, + waveform.unsqueeze(1), + ) + + return ret + + def save_wave(self, batch_wav, fname, save_dir): + os.makedirs(save_dir, exist_ok=True) + + for wav, name in zip(batch_wav, fname): + name = os.path.basename(name) + + sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate) + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs): + log = dict() + x = batch.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + log["samples"] = self.decode(posterior.sample()) + log["reconstructions"] = xrec + + log["inputs"] = x + wavs = self._log_img(log, train=train, index=0, waveform=waveform) + return wavs + + def _log_img(self, log, train=True, index=0, waveform=None): + images_input = self.tensor2numpy(log["inputs"][index, 0]).T + images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T + images_samples = self.tensor2numpy(log["samples"][index, 0]).T + + if train: + name = "train" + else: + name = "val" + + if self.logger is not None: + self.logger.log_image( + "img_%s" % name, + [images_input, images_reconstruct, images_samples], + caption=["input", "reconstruct", "samples"], + ) + + inputs, reconstructions, samples = ( + log["inputs"], + log["reconstructions"], + log["samples"], + ) + + if self.image_key == "fbank": + wav_original, wav_prediction = synth_one_sample( + inputs[index], + reconstructions[index], + labels="validation", + vocoder=self.vocoder, + ) + wav_original, wav_samples = synth_one_sample( + inputs[index], samples[index], labels="validation", vocoder=self.vocoder + ) + wav_original, wav_samples, wav_prediction = ( + wav_original[0], + wav_samples[0], + wav_prediction[0], + ) + elif self.image_key == "stft": + wav_prediction = ( + self.decode_to_waveform(reconstructions)[index, 0] + .cpu() + .detach() + .numpy() + ) + wav_samples = ( + self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy() + ) + wav_original = waveform[index, 0].cpu().detach().numpy() + + if self.logger is not None: + self.logger.experiment.log( + { + "original_%s" + % name: wandb.Audio( + wav_original, caption="original", sample_rate=self.sampling_rate + ), + "reconstruct_%s" + % name: wandb.Audio( + wav_prediction, + caption="reconstruct", + sample_rate=self.sampling_rate, + ), + "samples_%s" + % name: wandb.Audio( + wav_samples, caption="samples", sample_rate=self.sampling_rate + ), + } + ) + + return wav_original, wav_prediction, wav_samples + + def tensor2numpy(self, tensor): + return tensor.cpu().detach().numpy() + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/audioldm2/pipeline.py b/audioldm2/pipeline.py new file mode 100755 index 0000000000000000000000000000000000000000..1eec55b0198049f8baf263c3b80a7a8a0584ebeb --- /dev/null +++ b/audioldm2/pipeline.py @@ -0,0 +1,201 @@ +import os + +import yaml +import torch +import torchaudio + +from audioldm2.latent_diffusion.models.ddpm import LatentDiffusion +from audioldm2.utils import default_audioldm_config, get_metadata, download_checkpoint +from audioldm2.utilities.audio import read_wav_file +import os + +CACHE_DIR = os.getenv( + "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2") +) + + +def seed_everything(seed): + import random, os + import numpy as np + import torch + + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + + +def text_to_filename(text): + return text.replace(" ", "_").replace("'", "_").replace('"', "_") + + +def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec): + norm_mean = -4.2677393 + norm_std = 4.5689974 + + if sampling_rate != 16000: + waveform_16k = torchaudio.functional.resample( + waveform, orig_freq=sampling_rate, new_freq=16000 + ) + else: + waveform_16k = waveform + + waveform_16k = waveform_16k - waveform_16k.mean() + fbank = torchaudio.compliance.kaldi.fbank( + waveform_16k, + htk_compat=True, + sample_frequency=16000, + use_energy=False, + window_type="hanning", + num_mel_bins=128, + dither=0.0, + frame_shift=10, + ) + + TARGET_LEN = log_mel_spec.size(0) + + # cut and pad + n_frames = fbank.shape[0] + p = TARGET_LEN - n_frames + if p > 0: + m = torch.nn.ZeroPad2d((0, 0, 0, p)) + fbank = m(fbank) + elif p < 0: + fbank = fbank[:TARGET_LEN, :] + + fbank = (fbank - norm_mean) / (norm_std * 2) + + return {"ta_kaldi_fbank": fbank} # [1024, 128] + + +def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1): + text = [text] * batchsize + if batchsize < 1: + print("Warning: Batchsize must be at least 1. Batchsize is set to .") + + if fbank is None: + fbank = torch.zeros( + (batchsize, 1024, 64) + ) # Not used, here to keep the code format + else: + fbank = torch.FloatTensor(fbank) + fbank = fbank.expand(batchsize, 1024, 64) + assert fbank.size(0) == batchsize + + stft = torch.zeros((batchsize, 1024, 512)) # Not used + + if waveform is None: + waveform = torch.zeros((batchsize, 160000)) # Not used + ta_kaldi_fbank = torch.zeros((batchsize, 1024, 128)) + else: + waveform = torch.FloatTensor(waveform) + waveform = waveform.expand(batchsize, -1) + assert waveform.size(0) == batchsize + ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, 16000, fbank) + + batch = { + "text": text, # list + "fname": [text_to_filename(t) for t in text], # list + "waveform": waveform, + "stft": stft, + "log_mel_spec": fbank, + "ta_kaldi_fbank": ta_kaldi_fbank, + } + + return batch + + +def round_up_duration(duration): + return int(round(duration / 2.5) + 1) * 2.5 + + +def split_clap_weight_to_pth(checkpoint): + if os.path.exists(os.path.join(CACHE_DIR, "clap.pth")): + return + print("Constructing the weight for the CLAP model.") + include_keys = "cond_stage_models.0.cond_stage_models.0.model." + new_state_dict = {} + for each in checkpoint["state_dict"].keys(): + if include_keys in each: + new_state_dict[each.replace(include_keys, "module.")] = checkpoint[ + "state_dict" + ][each] + torch.save({"state_dict": new_state_dict}, os.path.join(CACHE_DIR, "clap.pth")) + + +def build_model(ckpt_path=None, config=None, model_name="audioldm2-full"): + print("Loading AudioLDM-2: %s" % model_name) + + if ckpt_path is None: + ckpt_path = get_metadata()[model_name]["path"] + + if not os.path.exists(ckpt_path): + download_checkpoint(model_name) + + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + + if config is not None: + assert type(config) is str + config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) + else: + config = default_audioldm_config(model_name) + + # # Use text as condition instead of using waveform during training + config["model"]["params"]["device"] = device + # config["model"]["params"]["cond_stage_key"] = "text" + + # No normalization here + latent_diffusion = LatentDiffusion(**config["model"]["params"]) + + resume_from_checkpoint = ckpt_path + + checkpoint = torch.load(resume_from_checkpoint, map_location=device) + + latent_diffusion.load_state_dict(checkpoint["state_dict"]) + + latent_diffusion.eval() + latent_diffusion = latent_diffusion.to(device) + + return latent_diffusion + +def duration_to_latent_t_size(duration): + return int(duration * 25.6) + +def text_to_audio( + latent_diffusion, + text, + seed=42, + ddim_steps=200, + duration=10, + batchsize=1, + guidance_scale=3.5, + n_candidate_gen_per_text=3, + config=None, +): + assert ( + duration == 10 + ), "Error: Currently we only support 10 seconds of generation. Generating longer files requires some extra coding, which would be a part of the future work." + + seed_everything(int(seed)) + waveform = None + + batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize) + + latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) + + with torch.no_grad(): + waveform = latent_diffusion.generate_batch( + batch, + unconditional_guidance_scale=guidance_scale, + ddim_steps=ddim_steps, + n_gen=n_candidate_gen_per_text, + duration=duration, + ) + + return waveform diff --git a/audioldm2/utilities/__init__.py b/audioldm2/utilities/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..495e8fe675337df0afacd3a31d06d0241b6b0e63 --- /dev/null +++ b/audioldm2/utilities/__init__.py @@ -0,0 +1,3 @@ +from .tools import * +from .data import * +from .model import * diff --git a/audioldm2/utilities/audio/__init__.py b/audioldm2/utilities/audio/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..c39f9243d2d7b4fc5dea18f56b153b0f5c5bbd4c --- /dev/null +++ b/audioldm2/utilities/audio/__init__.py @@ -0,0 +1,3 @@ +from .audio_processing import * +from .stft import * +from .tools import * diff --git a/audioldm2/utilities/audio/audio_processing.py b/audioldm2/utilities/audio/audio_processing.py new file mode 100755 index 0000000000000000000000000000000000000000..77a4057aa82f226f68474f4c2a19eba84510d663 --- /dev/null +++ b/audioldm2/utilities/audio/audio_processing.py @@ -0,0 +1,100 @@ +import torch +import numpy as np +import librosa.util as librosa_util +from scipy.signal import get_window + + +def window_sumsquare( + window, + n_frames, + hop_length, + win_length, + n_fft, + dtype=np.float32, + norm=None, +): + """ + # from librosa 0.6 + Compute the sum-square envelope of a window function at a given hop length. + + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. + + Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + + n_frames : int > 0 + The number of analysis frames + + hop_length : int > 0 + The number of samples to advance between frames + + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + + n_fft : int > 0 + The length of each analysis frame. + + dtype : np.dtype + The data type of the output + + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] + return x + + +def griffin_lim(magnitudes, stft_fn, n_iters=30): + """ + PARAMS + ------ + magnitudes: spectrogram magnitudes + stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods + """ + + angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) + angles = angles.astype(np.float32) + angles = torch.autograd.Variable(torch.from_numpy(angles)) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + + for i in range(n_iters): + _, angles = stft_fn.transform(signal) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + return signal + + +def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return normalize_fun(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C diff --git a/audioldm2/utilities/audio/stft.py b/audioldm2/utilities/audio/stft.py new file mode 100755 index 0000000000000000000000000000000000000000..508f33674e6dd8a5557205c8e77e07955df13a87 --- /dev/null +++ b/audioldm2/utilities/audio/stft.py @@ -0,0 +1,178 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy.signal import get_window +from librosa.util import pad_center, tiny +from librosa.filters import mel as librosa_mel_fn + +from audioldm2.utilities.audio.audio_processing import ( + dynamic_range_compression, + dynamic_range_decompression, + window_sumsquare, +) + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + + def __init__(self, filter_length, hop_length, win_length, window="hann"): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.forward_transform = None + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :] + ) + + if window is not None: + assert filter_length >= win_length + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer("forward_basis", forward_basis.float()) + self.register_buffer("inverse_basis", inverse_basis.float()) + + def transform(self, input_data): + num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode="reflect", + ) + input_data = input_data.squeeze(1) + + forward_transform = F.conv1d( + input_data, + torch.autograd.Variable(self.forward_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ).cpu() + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + torch.autograd.Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, + magnitude.size(-1), + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.filter_length, + dtype=np.float32, + ) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0] + ) + window_sum = torch.autograd.Variable( + torch.from_numpy(window_sum), requires_grad=False + ) + window_sum = window_sum + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ + approx_nonzero_indices + ] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] + inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] + + return inverse_transform + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + + +class TacotronSTFT(torch.nn.Module): + def __init__( + self, + filter_length, + hop_length, + win_length, + n_mel_channels, + sampling_rate, + mel_fmin, + mel_fmax, + ): + super(TacotronSTFT, self).__init__() + self.n_mel_channels = n_mel_channels + self.sampling_rate = sampling_rate + self.stft_fn = STFT(filter_length, hop_length, win_length) + mel_basis = librosa_mel_fn( + sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + + def spectral_normalize(self, magnitudes, normalize_fun): + output = dynamic_range_compression(magnitudes, normalize_fun) + return output + + def spectral_de_normalize(self, magnitudes): + output = dynamic_range_decompression(magnitudes) + return output + + def mel_spectrogram(self, y, normalize_fun=torch.log): + """Computes mel-spectrograms from a batch of waves + PARAMS + ------ + y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] + + RETURNS + ------- + mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) + """ + assert torch.min(y.data) >= -1, torch.min(y.data) + assert torch.max(y.data) <= 1, torch.max(y.data) + + magnitudes, phases = self.stft_fn.transform(y) + magnitudes = magnitudes.data + mel_output = torch.matmul(self.mel_basis, magnitudes) + mel_output = self.spectral_normalize(mel_output, normalize_fun) + energy = torch.norm(magnitudes, dim=1) + + return mel_output, magnitudes, phases, energy diff --git a/audioldm2/utilities/audio/tools.py b/audioldm2/utilities/audio/tools.py new file mode 100755 index 0000000000000000000000000000000000000000..8c666a7c67e0ae93edbad666520fd2e98fd29d18 --- /dev/null +++ b/audioldm2/utilities/audio/tools.py @@ -0,0 +1,69 @@ +import torch +import numpy as np +from scipy.io.wavfile import write +import torchaudio + +from audioldm2.utilities.audio.audio_processing import griffin_lim + + +def pad_wav(waveform, segment_length): + waveform_length = waveform.shape[-1] + assert waveform_length > 100, "Waveform is too short, %s" % waveform_length + if segment_length is None or waveform_length == segment_length: + return waveform + elif waveform_length > segment_length: + return waveform[:segment_length] + elif waveform_length < segment_length: + temp_wav = np.zeros((1, segment_length)) + temp_wav[:, :waveform_length] = waveform + return temp_wav + + +def normalize_wav(waveform): + waveform = waveform - np.mean(waveform) + waveform = waveform / (np.max(np.abs(waveform)) + 1e-8) + return waveform * 0.5 + + +def read_wav_file(filename, segment_length): + # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower + waveform, sr = torchaudio.load(filename) # Faster!!! + waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) + waveform = waveform.numpy()[0, ...] + waveform = normalize_wav(waveform) + waveform = waveform[None, ...] + waveform = pad_wav(waveform, segment_length) + + waveform = waveform / np.max(np.abs(waveform)) + waveform = 0.5 * waveform + + return waveform + + +def get_mel_from_wav(audio, _stft): + audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) + audio = torch.autograd.Variable(audio, requires_grad=False) + melspec, magnitudes, phases, energy = _stft.mel_spectrogram(audio) + melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) + magnitudes = torch.squeeze(magnitudes, 0).numpy().astype(np.float32) + energy = torch.squeeze(energy, 0).numpy().astype(np.float32) + return melspec, magnitudes, energy + + +def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): + mel = torch.stack([mel]) + mel_decompress = _stft.spectral_de_normalize(mel) + mel_decompress = mel_decompress.transpose(1, 2).data.cpu() + spec_from_mel_scaling = 1000 + spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) + spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) + spec_from_mel = spec_from_mel * spec_from_mel_scaling + + audio = griffin_lim( + torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters + ) + + audio = audio.squeeze() + audio = audio.cpu().numpy() + audio_path = out_filename + write(audio_path, _stft.sampling_rate, audio) diff --git a/audioldm2/utilities/data/__init__.py b/audioldm2/utilities/data/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..13a9804e72b88e3b9078940aee87db73788c1fb5 --- /dev/null +++ b/audioldm2/utilities/data/__init__.py @@ -0,0 +1 @@ +from .dataset import Dataset diff --git a/audioldm2/utilities/data/add_on.py b/audioldm2/utilities/data/add_on.py new file mode 100755 index 0000000000000000000000000000000000000000..4cfc6297e2f66759077c1540fc04b19560f3659c --- /dev/null +++ b/audioldm2/utilities/data/add_on.py @@ -0,0 +1,508 @@ +import os +import torch +import numpy as np +import torchaudio +import matplotlib.pyplot as plt + +CACHE = { + "get_vits_phoneme_ids": { + "PAD_LENGTH": 310, + "_pad": "_", + "_punctuation": ';:,.!?¡¿—…"«»“” ', + "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", + "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ", + "_special": "♪☎☒☝⚠", + } +} + +CACHE["get_vits_phoneme_ids"]["symbols"] = ( + [CACHE["get_vits_phoneme_ids"]["_pad"]] + + list(CACHE["get_vits_phoneme_ids"]["_punctuation"]) + + list(CACHE["get_vits_phoneme_ids"]["_letters"]) + + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"]) + + list(CACHE["get_vits_phoneme_ids"]["_special"]) +) +CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = { + s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"]) +} + + +def get_vits_phoneme_ids(config, dl_output, metadata): + pad_token_id = 0 + pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"] + _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] + + assert ( + "phonemes" in metadata.keys() + ), "You must provide vits phonemes on using addon get_vits_phoneme_ids" + clean_text = metadata["phonemes"] + sequence = [] + + for symbol in clean_text: + symbol_id = _symbol_to_id[symbol] + sequence += [symbol_id] + + inserted_zero_sequence = [0] * (len(sequence) * 2) + inserted_zero_sequence[1::2] = sequence + inserted_zero_sequence = inserted_zero_sequence + [0] + + def _pad_phonemes(phonemes_list): + return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list)) + + return {"phoneme_idx": torch.LongTensor(_pad_phonemes(inserted_zero_sequence))} + + +def get_vits_phoneme_ids_no_padding(config, dl_output, metadata): + pad_token_id = 0 + pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"] + _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] + + assert ( + "phonemes" in metadata.keys() + ), "You must provide vits phonemes on using addon get_vits_phoneme_ids" + clean_text = metadata["phonemes"] + "⚠" + sequence = [] + + for symbol in clean_text: + if symbol not in _symbol_to_id.keys(): + print("%s is not in the vocabulary. %s" % (symbol, clean_text)) + symbol = "_" + symbol_id = _symbol_to_id[symbol] + sequence += [symbol_id] + + def _pad_phonemes(phonemes_list): + return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list)) + + sequence = sequence[:pad_length] + + return {"phoneme_idx": torch.LongTensor(_pad_phonemes(sequence))} + + +def calculate_relative_bandwidth(config, dl_output, metadata): + assert "stft" in dl_output.keys() + + # The last dimension of the stft feature is the frequency dimension + freq_dimensions = dl_output["stft"].size(-1) + + freq_energy_dist = torch.sum(dl_output["stft"], dim=0) + freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0) + total_energy = freq_energy_dist[-1] + + percentile_5th = total_energy * 0.05 + percentile_95th = total_energy * 0.95 + + lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist)) + higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist)) + + lower_idx = int((lower_idx / freq_dimensions) * 1000) + higher_idx = int((higher_idx / freq_dimensions) * 1000) + + return {"freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx])} + + +def calculate_mel_spec_relative_bandwidth_as_extra_channel(config, dl_output, metadata): + assert "stft" in dl_output.keys() + linear_mel_spec = torch.exp(torch.clip(dl_output["log_mel_spec"], max=10)) + + # The last dimension of the stft feature is the frequency dimension + freq_dimensions = linear_mel_spec.size(-1) + freq_energy_dist = torch.sum(linear_mel_spec, dim=0) + freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0) + total_energy = freq_energy_dist[-1] + + percentile_5th = total_energy * 0.05 + percentile_95th = total_energy * 0.95 + + lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist)) + higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist)) + + latent_t_size = config["model"]["params"]["latent_t_size"] + latent_f_size = config["model"]["params"]["latent_f_size"] + + lower_idx = int(latent_f_size * float((lower_idx / freq_dimensions))) + higher_idx = int(latent_f_size * float((higher_idx / freq_dimensions))) + + bandwidth_condition = torch.zeros((latent_t_size, latent_f_size)) + bandwidth_condition[:, lower_idx:higher_idx] += 1.0 + + return { + "mel_spec_bandwidth_cond_extra_channel": bandwidth_condition, + "freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx]), + } + + +def waveform_rs_48k(config, dl_output, metadata): + waveform = dl_output["waveform"] # [1, samples] + sampling_rate = dl_output["sampling_rate"] + + if sampling_rate != 48000: + waveform_48k = torchaudio.functional.resample( + waveform, orig_freq=sampling_rate, new_freq=48000 + ) + else: + waveform_48k = waveform + + return {"waveform_48k": waveform_48k} + + +def extract_vits_phoneme_and_flant5_text(config, dl_output, metadata): + assert ( + "phoneme" not in metadata.keys() + ), "The metadata of speech you use seems belong to fastspeech. Please check dataset_root.json" + + if "phonemes" in metadata.keys(): + new_item = get_vits_phoneme_ids_no_padding(config, dl_output, metadata) + new_item["text"] = "" # We assume TTS data does not have text description + else: + fake_metadata = {"phonemes": ""} # Add empty phoneme sequence + new_item = get_vits_phoneme_ids_no_padding(config, dl_output, fake_metadata) + + return new_item + + +def extract_fs2_phoneme_and_flant5_text(config, dl_output, metadata): + if "phoneme" in metadata.keys(): + new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata) + new_item["text"] = "" + else: + fake_metadata = {"phoneme": []} + new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, fake_metadata) + return new_item + + +def extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata): + PAD_LENGTH = 135 + + phonemes_lookup_dict = { + "K": 0, + "IH2": 1, + "NG": 2, + "OW2": 3, + "AH2": 4, + "F": 5, + "AE0": 6, + "IY0": 7, + "SH": 8, + "G": 9, + "W": 10, + "UW1": 11, + "AO2": 12, + "AW2": 13, + "UW0": 14, + "EY2": 15, + "UW2": 16, + "AE2": 17, + "IH0": 18, + "P": 19, + "D": 20, + "ER1": 21, + "AA1": 22, + "EH0": 23, + "UH1": 24, + "N": 25, + "V": 26, + "AY1": 27, + "EY1": 28, + "UH2": 29, + "EH1": 30, + "L": 31, + "AA2": 32, + "R": 33, + "OY1": 34, + "Y": 35, + "ER2": 36, + "S": 37, + "AE1": 38, + "AH1": 39, + "JH": 40, + "ER0": 41, + "EH2": 42, + "IY2": 43, + "OY2": 44, + "AW1": 45, + "IH1": 46, + "IY1": 47, + "OW0": 48, + "AO0": 49, + "AY0": 50, + "EY0": 51, + "AY2": 52, + "UH0": 53, + "M": 54, + "TH": 55, + "T": 56, + "OY0": 57, + "AW0": 58, + "DH": 59, + "Z": 60, + "spn": 61, + "AH0": 62, + "sp": 63, + "AO1": 64, + "OW1": 65, + "ZH": 66, + "B": 67, + "AA0": 68, + "CH": 69, + "HH": 70, + } + pad_token_id = len(phonemes_lookup_dict.keys()) + + assert ( + "phoneme" in metadata.keys() + ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset" + + phonemes = [ + phonemes_lookup_dict[x] + for x in metadata["phoneme"] + if (x in phonemes_lookup_dict.keys()) + ] + + if (len(phonemes) / PAD_LENGTH) > 5: + print( + "Warning: Phonemes length is too long and is truncated too much! %s" + % metadata + ) + + phonemes = phonemes[:PAD_LENGTH] + + def _pad_phonemes(phonemes_list): + return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list)) + + return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))} + + +def extract_phoneme_g2p_en_feature(config, dl_output, metadata): + PAD_LENGTH = 250 + + phonemes_lookup_dict = { + " ": 0, + "AA": 1, + "AE": 2, + "AH": 3, + "AO": 4, + "AW": 5, + "AY": 6, + "B": 7, + "CH": 8, + "D": 9, + "DH": 10, + "EH": 11, + "ER": 12, + "EY": 13, + "F": 14, + "G": 15, + "HH": 16, + "IH": 17, + "IY": 18, + "JH": 19, + "K": 20, + "L": 21, + "M": 22, + "N": 23, + "NG": 24, + "OW": 25, + "OY": 26, + "P": 27, + "R": 28, + "S": 29, + "SH": 30, + "T": 31, + "TH": 32, + "UH": 33, + "UW": 34, + "V": 35, + "W": 36, + "Y": 37, + "Z": 38, + "ZH": 39, + } + pad_token_id = len(phonemes_lookup_dict.keys()) + + assert ( + "phoneme" in metadata.keys() + ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset" + phonemes = [ + phonemes_lookup_dict[x] + for x in metadata["phoneme"] + if (x in phonemes_lookup_dict.keys()) + ] + + if (len(phonemes) / PAD_LENGTH) > 5: + print( + "Warning: Phonemes length is too long and is truncated too much! %s" + % metadata + ) + + phonemes = phonemes[:PAD_LENGTH] + + def _pad_phonemes(phonemes_list): + return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list)) + + return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))} + + +def extract_kaldi_fbank_feature(config, dl_output, metadata): + norm_mean = -4.2677393 + norm_std = 4.5689974 + + waveform = dl_output["waveform"] # [1, samples] + sampling_rate = dl_output["sampling_rate"] + log_mel_spec_hifigan = dl_output["log_mel_spec"] + + if sampling_rate != 16000: + waveform_16k = torchaudio.functional.resample( + waveform, orig_freq=sampling_rate, new_freq=16000 + ) + else: + waveform_16k = waveform + + waveform_16k = waveform_16k - waveform_16k.mean() + fbank = torchaudio.compliance.kaldi.fbank( + waveform_16k, + htk_compat=True, + sample_frequency=16000, + use_energy=False, + window_type="hanning", + num_mel_bins=128, + dither=0.0, + frame_shift=10, + ) + + TARGET_LEN = log_mel_spec_hifigan.size(0) + + # cut and pad + n_frames = fbank.shape[0] + p = TARGET_LEN - n_frames + if p > 0: + m = torch.nn.ZeroPad2d((0, 0, 0, p)) + fbank = m(fbank) + elif p < 0: + fbank = fbank[:TARGET_LEN, :] + + fbank = (fbank - norm_mean) / (norm_std * 2) + + return {"ta_kaldi_fbank": fbank} # [1024, 128] + + +def extract_kaldi_fbank_feature_32k(config, dl_output, metadata): + norm_mean = -4.2677393 + norm_std = 4.5689974 + + waveform = dl_output["waveform"] # [1, samples] + sampling_rate = dl_output["sampling_rate"] + log_mel_spec_hifigan = dl_output["log_mel_spec"] + + if sampling_rate != 32000: + waveform_32k = torchaudio.functional.resample( + waveform, orig_freq=sampling_rate, new_freq=32000 + ) + else: + waveform_32k = waveform + + waveform_32k = waveform_32k - waveform_32k.mean() + fbank = torchaudio.compliance.kaldi.fbank( + waveform_32k, + htk_compat=True, + sample_frequency=32000, + use_energy=False, + window_type="hanning", + num_mel_bins=128, + dither=0.0, + frame_shift=10, + ) + + TARGET_LEN = log_mel_spec_hifigan.size(0) + + # cut and pad + n_frames = fbank.shape[0] + p = TARGET_LEN - n_frames + if p > 0: + m = torch.nn.ZeroPad2d((0, 0, 0, p)) + fbank = m(fbank) + elif p < 0: + fbank = fbank[:TARGET_LEN, :] + + fbank = (fbank - norm_mean) / (norm_std * 2) + + return {"ta_kaldi_fbank": fbank} # [1024, 128] + + +# Use the beat and downbeat information as music conditions +def extract_drum_beat(config, dl_output, metadata): + def visualization(conditional_signal, mel_spectrogram, filename): + import soundfile as sf + + sf.write( + os.path.basename(dl_output["fname"]), + np.array(dl_output["waveform"])[0], + dl_output["sampling_rate"], + ) + plt.figure(figsize=(10, 10)) + + plt.subplot(211) + plt.imshow(np.array(conditional_signal).T, aspect="auto") + plt.title("Conditional Signal") + + plt.subplot(212) + plt.imshow(np.array(mel_spectrogram).T, aspect="auto") + plt.title("Mel Spectrogram") + + plt.savefig(filename) + plt.close() + + assert "sample_rate" in metadata and "beat" in metadata and "downbeat" in metadata + + sampling_rate = metadata["sample_rate"] + duration = dl_output["duration"] + # The dataloader segment length before performing torch resampling + original_segment_length_before_resample = int(sampling_rate * duration) + + random_start_sample = int(dl_output["random_start_sample_in_original_audio_file"]) + + # The sample idx for beat and downbeat, relatively to the segmented audio + beat = [ + x - random_start_sample + for x in metadata["beat"] + if ( + x - random_start_sample >= 0 + and x - random_start_sample <= original_segment_length_before_resample + ) + ] + downbeat = [ + x - random_start_sample + for x in metadata["downbeat"] + if ( + x - random_start_sample >= 0 + and x - random_start_sample <= original_segment_length_before_resample + ) + ] + + latent_shape = ( + config["model"]["params"]["latent_t_size"], + config["model"]["params"]["latent_f_size"], + ) + conditional_signal = torch.zeros(latent_shape) + + # beat: -0.5 + # downbeat: +1.0 + # 0: none; -0.5: beat; 1.0: downbeat; 0.5: downbeat+beat + for each in beat: + beat_index = int( + (each / original_segment_length_before_resample) * latent_shape[0] + ) + beat_index = min(beat_index, conditional_signal.size(0) - 1) + + conditional_signal[beat_index, :] -= 0.5 + + for each in downbeat: + beat_index = int( + (each / original_segment_length_before_resample) * latent_shape[0] + ) + beat_index = min(beat_index, conditional_signal.size(0) - 1) + + conditional_signal[beat_index, :] += 1.0 + + # visualization(conditional_signal, dl_output["log_mel_spec"], filename = os.path.basename(dl_output["fname"])+".png") + + return {"cond_beat_downbeat": conditional_signal} diff --git a/audioldm2/utilities/data/dataset.py b/audioldm2/utilities/data/dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..f0bfbb7388ca6473beb4574ac4e29dcf0b7c0571 --- /dev/null +++ b/audioldm2/utilities/data/dataset.py @@ -0,0 +1,518 @@ +import os +import pandas as pd + +import audioldm2.utilities.audio as Audio +from audioldm2.utilities.tools import load_json + +import random +from torch.utils.data import Dataset +import torch.nn.functional +import torch +import numpy as np +import torchaudio + + +class AudioDataset(Dataset): + def __init__( + self, + config=None, + split="train", + waveform_only=False, + add_ons=[], + dataset_json_path=None, # + ): + """ + Dataset that manages audio recordings + :param audio_conf: Dictionary containing the audio loading and preprocessing settings + :param dataset_json_file + """ + self.config = config + self.split = split + self.pad_wav_start_sample = 0 # If none, random choose + self.trim_wav = False + self.waveform_only = waveform_only + self.add_ons = [eval(x) for x in add_ons] + print("Add-ons:", self.add_ons) + + self.build_setting_parameters() + + # For an external dataset + if dataset_json_path is not None: + assert type(dataset_json_path) == str + print("Load metadata from %s" % dataset_json_path) + self.data = load_json(dataset_json_path)["data"] + self.id2label, self.index_dict, self.num2label = {}, {}, {} + else: + self.metadata_root = load_json(self.config["metadata_root"]) + self.dataset_name = self.config["data"][self.split] + assert split in self.config["data"].keys(), ( + "The dataset split %s you specified is not present in the config. You can choose from %s" + % (split, self.config["data"].keys()) + ) + self.build_dataset() + self.build_id_to_label() + + self.build_dsp() + self.label_num = len(self.index_dict) + print("Dataset initialize finished") + + def __getitem__(self, index): + ( + fname, + waveform, + stft, + log_mel_spec, + label_vector, # the one-hot representation of the audio class + # the metadata of the sampled audio file and the mixup audio file (if exist) + (datum, mix_datum), + random_start, + ) = self.feature_extraction(index) + text = self.get_sample_text_caption(datum, mix_datum, label_vector) + + data = { + "text": text, # list + "fname": self.text_to_filename(text) + if (len(fname) == 0) + else fname, # list + # tensor, [batchsize, class_num] + "label_vector": "" if (label_vector is None) else label_vector.float(), + # tensor, [batchsize, 1, samples_num] + "waveform": "" if (waveform is None) else waveform.float(), + # tensor, [batchsize, t-steps, f-bins] + "stft": "" if (stft is None) else stft.float(), + # tensor, [batchsize, t-steps, mel-bins] + "log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(), + "duration": self.duration, + "sampling_rate": self.sampling_rate, + "random_start_sample_in_original_audio_file": random_start, + } + + for add_on in self.add_ons: + data.update(add_on(self.config, data, self.data[index])) + + if data["text"] is None: + print("Warning: The model return None on key text", fname) + data["text"] = "" + + return data + + def text_to_filename(self, text): + return text.replace(" ", "_").replace("'", "_").replace('"', "_") + + def get_dataset_root_path(self, dataset): + assert dataset in self.metadata_root.keys() + return self.metadata_root[dataset] + + def get_dataset_metadata_path(self, dataset, key): + # key: train, test, val, class_label_indices + try: + if dataset in self.metadata_root["metadata"]["path"].keys(): + return self.metadata_root["metadata"]["path"][dataset][key] + except: + raise ValueError( + 'Dataset %s does not metadata "%s" specified' % (dataset, key) + ) + # return None + + def __len__(self): + return len(self.data) + + def feature_extraction(self, index): + if index > len(self.data) - 1: + print( + "The index of the dataloader is out of range: %s/%s" + % (index, len(self.data)) + ) + index = random.randint(0, len(self.data) - 1) + + # Read wave file and extract feature + while True: + try: + label_indices = np.zeros(self.label_num, dtype=np.float32) + datum = self.data[index] + ( + log_mel_spec, + stft, + mix_lambda, + waveform, + random_start, + ) = self.read_audio_file(datum["wav"]) + mix_datum = None + if self.label_num > 0 and "labels" in datum.keys(): + for label_str in datum["labels"].split(","): + label_indices[int(self.index_dict[label_str])] = 1.0 + + # If the key "label" is not in the metadata, return all zero vector + label_indices = torch.FloatTensor(label_indices) + break + except Exception as e: + index = (index + 1) % len(self.data) + print( + "Error encounter during audio feature extraction: ", e, datum["wav"] + ) + continue + + # The filename of the wav file + fname = datum["wav"] + # t_step = log_mel_spec.size(0) + # waveform = torch.FloatTensor(waveform[..., : int(self.hopsize * t_step)]) + waveform = torch.FloatTensor(waveform) + + return ( + fname, + waveform, + stft, + log_mel_spec, + label_indices, + (datum, mix_datum), + random_start, + ) + + # def augmentation(self, log_mel_spec): + # assert torch.min(log_mel_spec) < 0 + # log_mel_spec = log_mel_spec.exp() + + # log_mel_spec = torch.transpose(log_mel_spec, 0, 1) + # # this is just to satisfy new torchaudio version. + # log_mel_spec = log_mel_spec.unsqueeze(0) + # if self.freqm != 0: + # log_mel_spec = self.frequency_masking(log_mel_spec, self.freqm) + # if self.timem != 0: + # log_mel_spec = self.time_masking( + # log_mel_spec, self.timem) # self.timem=0 + + # log_mel_spec = (log_mel_spec + 1e-7).log() + # # squeeze back + # log_mel_spec = log_mel_spec.squeeze(0) + # log_mel_spec = torch.transpose(log_mel_spec, 0, 1) + # return log_mel_spec + + def build_setting_parameters(self): + # Read from the json config + self.melbins = self.config["preprocessing"]["mel"]["n_mel_channels"] + # self.freqm = self.config["preprocessing"]["mel"]["freqm"] + # self.timem = self.config["preprocessing"]["mel"]["timem"] + self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"] + self.hopsize = self.config["preprocessing"]["stft"]["hop_length"] + self.duration = self.config["preprocessing"]["audio"]["duration"] + self.target_length = int(self.duration * self.sampling_rate / self.hopsize) + + self.mixup = self.config["augmentation"]["mixup"] + + # Calculate parameter derivations + # self.waveform_sample_length = int(self.target_length * self.hopsize) + + # if (self.config["balance_sampling_weight"]): + # self.samples_weight = np.loadtxt( + # self.config["balance_sampling_weight"], delimiter="," + # ) + + if "train" not in self.split: + self.mixup = 0.0 + # self.freqm = 0 + # self.timem = 0 + + def _relative_path_to_absolute_path(self, metadata, dataset_name): + root_path = self.get_dataset_root_path(dataset_name) + for i in range(len(metadata["data"])): + assert "wav" in metadata["data"][i].keys(), metadata["data"][i] + assert metadata["data"][i]["wav"][0] != "/", ( + "The dataset metadata should only contain relative path to the audio file: " + + str(metadata["data"][i]["wav"]) + ) + metadata["data"][i]["wav"] = os.path.join( + root_path, metadata["data"][i]["wav"] + ) + return metadata + + def build_dataset(self): + self.data = [] + print("Build dataset split %s from %s" % (self.split, self.dataset_name)) + if type(self.dataset_name) is str: + data_json = load_json( + self.get_dataset_metadata_path(self.dataset_name, key=self.split) + ) + data_json = self._relative_path_to_absolute_path( + data_json, self.dataset_name + ) + self.data = data_json["data"] + elif type(self.dataset_name) is list: + for dataset_name in self.dataset_name: + data_json = load_json( + self.get_dataset_metadata_path(dataset_name, key=self.split) + ) + data_json = self._relative_path_to_absolute_path( + data_json, dataset_name + ) + self.data += data_json["data"] + else: + raise Exception("Invalid data format") + print("Data size: {}".format(len(self.data))) + + def build_dsp(self): + self.STFT = Audio.stft.TacotronSTFT( + self.config["preprocessing"]["stft"]["filter_length"], + self.config["preprocessing"]["stft"]["hop_length"], + self.config["preprocessing"]["stft"]["win_length"], + self.config["preprocessing"]["mel"]["n_mel_channels"], + self.config["preprocessing"]["audio"]["sampling_rate"], + self.config["preprocessing"]["mel"]["mel_fmin"], + self.config["preprocessing"]["mel"]["mel_fmax"], + ) + # self.stft_transform = torchaudio.transforms.Spectrogram( + # n_fft=1024, hop_length=160 + # ) + # self.melscale_transform = torchaudio.transforms.MelScale( + # sample_rate=16000, n_stft=1024 // 2 + 1, n_mels=64 + # ) + + def build_id_to_label(self): + id2label = {} + id2num = {} + num2label = {} + class_label_indices_path = self.get_dataset_metadata_path( + dataset=self.config["data"]["class_label_indices"], + key="class_label_indices", + ) + if class_label_indices_path is not None: + df = pd.read_csv(class_label_indices_path) + for _, row in df.iterrows(): + index, mid, display_name = row["index"], row["mid"], row["display_name"] + id2label[mid] = display_name + id2num[mid] = index + num2label[index] = display_name + self.id2label, self.index_dict, self.num2label = id2label, id2num, num2label + else: + self.id2label, self.index_dict, self.num2label = {}, {}, {} + + def resample(self, waveform, sr): + waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate) + # waveform = librosa.resample(waveform, sr, self.sampling_rate) + return waveform + + # if sr == 16000: + # return waveform + # if sr == 32000 and self.sampling_rate == 16000: + # waveform = waveform[::2] + # return waveform + # if sr == 48000 and self.sampling_rate == 16000: + # waveform = waveform[::3] + # return waveform + # else: + # raise ValueError( + # "We currently only support 16k audio generation. You need to resample you audio file to 16k, 32k, or 48k: %s, %s" + # % (sr, self.sampling_rate) + # ) + + def normalize_wav(self, waveform): + waveform = waveform - np.mean(waveform) + waveform = waveform / (np.max(np.abs(waveform)) + 1e-8) + return waveform * 0.5 # Manually limit the maximum amplitude into 0.5 + + def random_segment_wav(self, waveform, target_length): + waveform_length = waveform.shape[-1] + assert waveform_length > 100, "Waveform is too short, %s" % waveform_length + + # Too short + if (waveform_length - target_length) <= 0: + return waveform, 0 + + random_start = int(self.random_uniform(0, waveform_length - target_length)) + return waveform[:, random_start : random_start + target_length], random_start + + def pad_wav(self, waveform, target_length): + waveform_length = waveform.shape[-1] + assert waveform_length > 100, "Waveform is too short, %s" % waveform_length + + if waveform_length == target_length: + return waveform + + # Pad + temp_wav = np.zeros((1, target_length), dtype=np.float32) + if self.pad_wav_start_sample is None: + rand_start = int(self.random_uniform(0, target_length - waveform_length)) + else: + rand_start = 0 + + temp_wav[:, rand_start : rand_start + waveform_length] = waveform + return temp_wav + + def trim_wav(self, waveform): + if np.max(np.abs(waveform)) < 0.0001: + return waveform + + def detect_leading_silence(waveform, threshold=0.0001): + chunk_size = 1000 + waveform_length = waveform.shape[0] + start = 0 + while start + chunk_size < waveform_length: + if np.max(np.abs(waveform[start : start + chunk_size])) < threshold: + start += chunk_size + else: + break + return start + + def detect_ending_silence(waveform, threshold=0.0001): + chunk_size = 1000 + waveform_length = waveform.shape[0] + start = waveform_length + while start - chunk_size > 0: + if np.max(np.abs(waveform[start - chunk_size : start])) < threshold: + start -= chunk_size + else: + break + if start == waveform_length: + return start + else: + return start + chunk_size + + start = detect_leading_silence(waveform) + end = detect_ending_silence(waveform) + + return waveform[start:end] + + def read_wav_file(self, filename): + # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower + waveform, sr = torchaudio.load(filename) + + waveform, random_start = self.random_segment_wav( + waveform, target_length=int(sr * self.duration) + ) + + waveform = self.resample(waveform, sr) + # random_start = int(random_start * (self.sampling_rate / sr)) + + waveform = waveform.numpy()[0, ...] + + waveform = self.normalize_wav(waveform) + + if self.trim_wav: + waveform = self.trim_wav(waveform) + + waveform = waveform[None, ...] + waveform = self.pad_wav( + waveform, target_length=int(self.sampling_rate * self.duration) + ) + return waveform, random_start + + def mix_two_waveforms(self, waveform1, waveform2): + mix_lambda = np.random.beta(5, 5) + mix_waveform = mix_lambda * waveform1 + (1 - mix_lambda) * waveform2 + return self.normalize_wav(mix_waveform), mix_lambda + + def read_audio_file(self, filename, filename2=None): + if os.path.exists(filename): + waveform, random_start = self.read_wav_file(filename) + else: + print( + 'Warning [dataset.py]: The wav path "', + filename, + '" is not find in the metadata. Use empty waveform instead.', + ) + target_length = int(self.sampling_rate * self.duration) + waveform = torch.zeros((1, target_length)) + random_start = 0 + + mix_lambda = 0.0 + # log_mel_spec, stft = self.wav_feature_extraction_torchaudio(waveform) # this line is faster, but this implementation is not aligned with HiFi-GAN + if not self.waveform_only: + log_mel_spec, stft = self.wav_feature_extraction(waveform) + else: + # Load waveform data only + # Use zero array to keep the format unified + log_mel_spec, stft = None, None + + return log_mel_spec, stft, mix_lambda, waveform, random_start + + def get_sample_text_caption(self, datum, mix_datum, label_indices): + text = self.label_indices_to_text(datum, label_indices) + if mix_datum is not None: + text += " " + self.label_indices_to_text(mix_datum, label_indices) + return text + + # This one is significantly slower than "wav_feature_extraction_torchaudio" if num_worker > 1 + def wav_feature_extraction(self, waveform): + waveform = waveform[0, ...] + waveform = torch.FloatTensor(waveform) + + log_mel_spec, stft, energy = Audio.tools.get_mel_from_wav(waveform, self.STFT) + + log_mel_spec = torch.FloatTensor(log_mel_spec.T) + stft = torch.FloatTensor(stft.T) + + log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft) + return log_mel_spec, stft + + # @profile + # def wav_feature_extraction_torchaudio(self, waveform): + # waveform = waveform[0, ...] + # waveform = torch.FloatTensor(waveform) + + # stft = self.stft_transform(waveform) + # mel_spec = self.melscale_transform(stft) + # log_mel_spec = torch.log(mel_spec + 1e-7) + + # log_mel_spec = torch.FloatTensor(log_mel_spec.T) + # stft = torch.FloatTensor(stft.T) + + # log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft) + # return log_mel_spec, stft + + def pad_spec(self, log_mel_spec): + n_frames = log_mel_spec.shape[0] + p = self.target_length - n_frames + # cut and pad + if p > 0: + m = torch.nn.ZeroPad2d((0, 0, 0, p)) + log_mel_spec = m(log_mel_spec) + elif p < 0: + log_mel_spec = log_mel_spec[0 : self.target_length, :] + + if log_mel_spec.size(-1) % 2 != 0: + log_mel_spec = log_mel_spec[..., :-1] + + return log_mel_spec + + def _read_datum_caption(self, datum): + caption_keys = [x for x in datum.keys() if ("caption" in x)] + random_index = torch.randint(0, len(caption_keys), (1,))[0].item() + return datum[caption_keys[random_index]] + + def _is_contain_caption(self, datum): + caption_keys = [x for x in datum.keys() if ("caption" in x)] + return len(caption_keys) > 0 + + def label_indices_to_text(self, datum, label_indices): + if self._is_contain_caption(datum): + return self._read_datum_caption(datum) + elif "label" in datum.keys(): + name_indices = torch.where(label_indices > 0.1)[0] + # description_header = "This audio contains the sound of " + description_header = "" + labels = "" + for id, each in enumerate(name_indices): + if id == len(name_indices) - 1: + labels += "%s." % self.num2label[int(each)] + else: + labels += "%s, " % self.num2label[int(each)] + return description_header + labels + else: + return "" # TODO, if both label and caption are not provided, return empty string + + def random_uniform(self, start, end): + val = torch.rand(1).item() + return start + (end - start) * val + + def frequency_masking(self, log_mel_spec, freqm): + bs, freq, tsteps = log_mel_spec.size() + mask_len = int(self.random_uniform(freqm // 8, freqm)) + mask_start = int(self.random_uniform(start=0, end=freq - mask_len)) + log_mel_spec[:, mask_start : mask_start + mask_len, :] *= 0.0 + return log_mel_spec + + def time_masking(self, log_mel_spec, timem): + bs, freq, tsteps = log_mel_spec.size() + mask_len = int(self.random_uniform(timem // 8, timem)) + mask_start = int(self.random_uniform(start=0, end=tsteps - mask_len)) + log_mel_spec[:, :, mask_start : mask_start + mask_len] *= 0.0 + return log_mel_spec diff --git a/audioldm2/utilities/model.py b/audioldm2/utilities/model.py new file mode 100755 index 0000000000000000000000000000000000000000..ffefac1212b85bfb8c4992371dbdf6d500a969e3 --- /dev/null +++ b/audioldm2/utilities/model.py @@ -0,0 +1,121 @@ +import torch + +import audioldm2.hifigan as hifigan + + +def get_vocoder_config(): + return { + "resblock": "1", + "num_gpus": 6, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], + "upsample_initial_channel": 1024, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "segment_size": 8192, + "num_mels": 64, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 160, + "win_size": 1024, + "sampling_rate": 16000, + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": None, + "num_workers": 4, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, + } + + +def get_available_checkpoint_keys(model, ckpt): + state_dict = torch.load(ckpt)["state_dict"] + current_state_dict = model.state_dict() + new_state_dict = {} + for k in state_dict.keys(): + if ( + k in current_state_dict.keys() + and current_state_dict[k].size() == state_dict[k].size() + ): + new_state_dict[k] = state_dict[k] + else: + print("==> WARNING: Skipping %s" % k) + print( + "%s out of %s keys are matched" + % (len(new_state_dict.keys()), len(state_dict.keys())) + ) + return new_state_dict + + +def get_param_num(model): + num_param = sum(param.numel() for param in model.parameters()) + return num_param + + +def torch_version_orig_mod_remove(state_dict): + new_state_dict = {} + new_state_dict["generator"] = {} + for key in state_dict["generator"].keys(): + if "_orig_mod." in key: + new_state_dict["generator"][key.replace("_orig_mod.", "")] = state_dict[ + "generator" + ][key] + else: + new_state_dict["generator"][key] = state_dict["generator"][key] + return new_state_dict + + +def get_vocoder(config, device, mel_bins): + name = "HiFi-GAN" + speaker = "" + if name == "MelGAN": + if speaker == "LJSpeech": + vocoder = torch.hub.load( + "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" + ) + elif speaker == "universal": + vocoder = torch.hub.load( + "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" + ) + vocoder.mel2wav.eval() + vocoder.mel2wav.to(device) + elif name == "HiFi-GAN": + config = get_vocoder_config() + config = hifigan.AttrDict(config) + vocoder = hifigan.Generator_old(config) + # print("Load hifigan/g_01080000") + # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000")) + # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000")) + # ckpt = torch_version_orig_mod_remove(ckpt) + # vocoder.load_state_dict(ckpt["generator"]) + vocoder.eval() + vocoder.remove_weight_norm() + vocoder.to(device) + return vocoder + + +def vocoder_infer(mels, vocoder, lengths=None): + with torch.no_grad(): + wavs = vocoder(mels).squeeze(1) + + wavs = (wavs.cpu().numpy() * 32768).astype("int16") + + if lengths is not None: + wavs = wavs[:, :lengths] + + # wavs = [wav for wav in wavs] + + # for i in range(len(mels)): + # if lengths is not None: + # wavs[i] = wavs[i][: lengths[i]] + + return wavs diff --git a/audioldm2/utilities/sampler.py b/audioldm2/utilities/sampler.py new file mode 100755 index 0000000000000000000000000000000000000000..cdaf4882715f53f39ead8bf71fb3dccc29cd8b94 --- /dev/null +++ b/audioldm2/utilities/sampler.py @@ -0,0 +1,588 @@ +from typing import Iterator, List, Optional, Union +from collections import Counter +import logging +from operator import itemgetter +import random + +import numpy as np + +from torch.utils.data import DistributedSampler +from torch.utils.data.sampler import Sampler + +LOGGER = logging.getLogger(__name__) + +from torch.utils.data import Dataset, Sampler + + +class DatasetFromSampler(Dataset): + """Dataset to create indexes from `Sampler`. + Args: + sampler: PyTorch sampler + """ + + def __init__(self, sampler: Sampler): + """Initialisation for DatasetFromSampler.""" + self.sampler = sampler + self.sampler_list = None + + def __getitem__(self, index: int): + """Gets element of the dataset. + Args: + index: index of the element in the dataset + Returns: + Single element by index + """ + if self.sampler_list is None: + self.sampler_list = list(self.sampler) + return self.sampler_list[index] + + def __len__(self) -> int: + """ + Returns: + int: length of the dataset + """ + return len(self.sampler) + + +class BalanceClassSampler(Sampler): + """Allows you to create stratified sample on unbalanced classes. + + Args: + labels: list of class label for each elem in the dataset + mode: Strategy to balance classes. + Must be one of [downsampling, upsampling] + + Python API examples: + + .. code-block:: python + + import os + from torch import nn, optim + from torch.utils.data import DataLoader + from catalyst import dl + from catalyst.data import ToTensor, BalanceClassSampler + from catalyst.contrib.datasets import MNIST + + train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()) + train_labels = train_data.targets.cpu().numpy().tolist() + train_sampler = BalanceClassSampler(train_labels, mode=5000) + valid_data = MNIST(os.getcwd(), train=False) + + loaders = { + "train": DataLoader(train_data, sampler=train_sampler, batch_size=32), + "valid": DataLoader(valid_data, batch_size=32), + } + + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=0.02) + + runner = dl.SupervisedRunner() + # model training + runner.train( + model=model, + criterion=criterion, + optimizer=optimizer, + loaders=loaders, + num_epochs=1, + logdir="./logs", + valid_loader="valid", + valid_metric="loss", + minimize_valid_metric=True, + verbose=True, + ) + """ + + def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"): + """Sampler initialisation.""" + super().__init__(labels) + + labels = np.array(labels) + samples_per_class = {label: (labels == label).sum() for label in set(labels)} + + self.lbl2idx = { + label: np.arange(len(labels))[labels == label].tolist() + for label in set(labels) + } + + if isinstance(mode, str): + assert mode in ["downsampling", "upsampling"] + + if isinstance(mode, int) or mode == "upsampling": + samples_per_class = ( + mode if isinstance(mode, int) else max(samples_per_class.values()) + ) + else: + samples_per_class = min(samples_per_class.values()) + + self.labels = labels + self.samples_per_class = samples_per_class + self.length = self.samples_per_class * len(set(labels)) + + def __iter__(self) -> Iterator[int]: + """ + Returns: + iterator of indices of stratified sample + """ + indices = [] + for key in sorted(self.lbl2idx): + replace_flag = self.samples_per_class > len(self.lbl2idx[key]) + indices += np.random.choice( + self.lbl2idx[key], self.samples_per_class, replace=replace_flag + ).tolist() + assert len(indices) == self.length + np.random.shuffle(indices) + + return iter(indices) + + def __len__(self) -> int: + """ + Returns: + length of result sample + """ + return self.length + + +class BatchBalanceClassSampler(Sampler): + """ + This kind of sampler can be used for both metric learning and classification task. + + BatchSampler with the given strategy for the C unique classes dataset: + - Selection `num_classes` of C classes for each batch + - Selection `num_samples` instances for each class in the batch + The epoch ends after `num_batches`. + So, the batch sise is `num_classes` * `num_samples`. + + One of the purposes of this sampler is to be used for + forming triplets and pos/neg pairs inside the batch. + To guarante existance of these pairs in the batch, + `num_classes` and `num_samples` should be > 1. (1) + + This type of sampling can be found in the classical paper of Person Re-Id, + where P (`num_classes`) equals 32 and K (`num_samples`) equals 4: + `In Defense of the Triplet Loss for Person Re-Identification`_. + + Args: + labels: list of classes labeles for each elem in the dataset + num_classes: number of classes in a batch, should be > 1 + num_samples: number of instances of each class in a batch, should be > 1 + num_batches: number of batches in epoch + (default = len(labels) // (num_classes * num_samples)) + + .. _In Defense of the Triplet Loss for Person Re-Identification: + https://arxiv.org/abs/1703.07737 + + Python API examples: + + .. code-block:: python + + import os + from torch import nn, optim + from torch.utils.data import DataLoader + from catalyst import dl + from catalyst.data import ToTensor, BatchBalanceClassSampler + from catalyst.contrib.datasets import MNIST + + train_data = MNIST(os.getcwd(), train=True, download=True) + train_labels = train_data.targets.cpu().numpy().tolist() + train_sampler = BatchBalanceClassSampler( + train_labels, num_classes=10, num_samples=4) + valid_data = MNIST(os.getcwd(), train=False) + + loaders = { + "train": DataLoader(train_data, batch_sampler=train_sampler), + "valid": DataLoader(valid_data, batch_size=32), + } + + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=0.02) + + runner = dl.SupervisedRunner() + # model training + runner.train( + model=model, + criterion=criterion, + optimizer=optimizer, + loaders=loaders, + num_epochs=1, + logdir="./logs", + valid_loader="valid", + valid_metric="loss", + minimize_valid_metric=True, + verbose=True, + ) + """ + + def __init__( + self, + labels: Union[List[int], np.ndarray], + num_classes: int, + num_samples: int, + num_batches: int = None, + ): + """Sampler initialisation.""" + super().__init__(labels) + classes = set(labels) + + assert isinstance(num_classes, int) and isinstance(num_samples, int) + assert (1 < num_classes <= len(classes)) and (1 < num_samples) + assert all( + n > 1 for n in Counter(labels).values() + ), "Each class shoud contain at least 2 instances to fit (1)" + + labels = np.array(labels) + self._labels = list(set(labels.tolist())) + self._num_classes = num_classes + self._num_samples = num_samples + self._batch_size = self._num_classes * self._num_samples + self._num_batches = num_batches or len(labels) // self._batch_size + self.lbl2idx = { + label: np.arange(len(labels))[labels == label].tolist() + for label in set(labels) + } + + @property + def batch_size(self) -> int: + """ + Returns: + this value should be used in DataLoader as batch size + """ + return self._batch_size + + @property + def batches_in_epoch(self) -> int: + """ + Returns: + number of batches in an epoch + """ + return self._num_batches + + def __len__(self) -> int: + """ + Returns: + number of samples in an epoch + """ + return self._num_batches # * self._batch_size + + def __iter__(self) -> Iterator[int]: + """ + Returns: + indeces for sampling dataset elems during an epoch + """ + indices = [] + for _ in range(self._num_batches): + batch_indices = [] + classes_for_batch = random.sample(self._labels, self._num_classes) + while self._num_classes != len(set(classes_for_batch)): + classes_for_batch = random.sample(self._labels, self._num_classes) + for cls_id in classes_for_batch: + replace_flag = self._num_samples > len(self.lbl2idx[cls_id]) + batch_indices += np.random.choice( + self.lbl2idx[cls_id], self._num_samples, replace=replace_flag + ).tolist() + indices.append(batch_indices) + return iter(indices) + + +class DynamicBalanceClassSampler(Sampler): + """ + This kind of sampler can be used for classification tasks with significant + class imbalance. + + The idea of this sampler that we start with the original class distribution + and gradually move to uniform class distribution like with downsampling. + + Let's define :math: D_i = #C_i/ #C_min where :math: #C_i is a size of class + i and :math: #C_min is a size of the rarest class, so :math: D_i define + class distribution. Also define :math: g(n_epoch) is a exponential + scheduler. On each epoch current :math: D_i calculated as + :math: current D_i = D_i ^ g(n_epoch), + after this data samples according this distribution. + + Notes: + In the end of the training, epochs will contain only + min_size_class * n_classes examples. So, possible it will not + necessary to do validation on each epoch. For this reason use + ControlFlowCallback. + + Examples: + + >>> import torch + >>> import numpy as np + + >>> from catalyst.data import DynamicBalanceClassSampler + >>> from torch.utils import data + + >>> features = torch.Tensor(np.random.random((200, 100))) + >>> labels = np.random.randint(0, 4, size=(200,)) + >>> sampler = DynamicBalanceClassSampler(labels) + >>> labels = torch.LongTensor(labels) + >>> dataset = data.TensorDataset(features, labels) + >>> loader = data.dataloader.DataLoader(dataset, batch_size=8) + + >>> for batch in loader: + >>> b_features, b_labels = batch + + Sampler was inspired by https://arxiv.org/abs/1901.06783 + """ + + def __init__( + self, + labels: List[Union[int, str]], + exp_lambda: float = 0.9, + start_epoch: int = 0, + max_d: Optional[int] = None, + mode: Union[str, int] = "downsampling", + ignore_warning: bool = False, + ): + """ + Args: + labels: list of labels for each elem in the dataset + exp_lambda: exponent figure for schedule + start_epoch: start epoch number, can be useful for multi-stage + experiments + max_d: if not None, limit on the difference between the most + frequent and the rarest classes, heuristic + mode: number of samples per class in the end of training. Must be + "downsampling" or number. Before change it, make sure that you + understand how does it work + ignore_warning: ignore warning about min class size + """ + assert isinstance(start_epoch, int) + assert 0 < exp_lambda < 1, "exp_lambda must be in (0, 1)" + super().__init__(labels) + self.exp_lambda = exp_lambda + if max_d is None: + max_d = np.inf + self.max_d = max_d + self.epoch = start_epoch + labels = np.array(labels) + samples_per_class = Counter(labels) + self.min_class_size = min(samples_per_class.values()) + + if self.min_class_size < 100 and not ignore_warning: + LOGGER.warning( + f"the smallest class contains only" + f" {self.min_class_size} examples. At the end of" + f" training, epochs will contain only" + f" {self.min_class_size * len(samples_per_class)}" + f" examples" + ) + + self.original_d = { + key: value / self.min_class_size for key, value in samples_per_class.items() + } + self.label2idxes = { + label: np.arange(len(labels))[labels == label].tolist() + for label in set(labels) + } + + if isinstance(mode, int): + self.min_class_size = mode + else: + assert mode == "downsampling" + + self.labels = labels + self._update() + + def _update(self) -> None: + """Update d coefficients.""" + current_d = { + key: min(value ** self._exp_scheduler(), self.max_d) + for key, value in self.original_d.items() + } + samples_per_classes = { + key: int(value * self.min_class_size) for key, value in current_d.items() + } + self.samples_per_classes = samples_per_classes + self.length = np.sum(list(samples_per_classes.values())) + self.epoch += 1 + + def _exp_scheduler(self) -> float: + return self.exp_lambda**self.epoch + + def __iter__(self) -> Iterator[int]: + """ + Returns: + iterator of indices of stratified sample + """ + indices = [] + for key in sorted(self.label2idxes): + samples_per_class = self.samples_per_classes[key] + replace_flag = samples_per_class > len(self.label2idxes[key]) + indices += np.random.choice( + self.label2idxes[key], samples_per_class, replace=replace_flag + ).tolist() + assert len(indices) == self.length + np.random.shuffle(indices) + self._update() + return iter(indices) + + def __len__(self) -> int: + """ + Returns: + length of result sample + """ + return self.length + + +class MiniEpochSampler(Sampler): + """ + Sampler iterates mini epochs from the dataset used by ``mini_epoch_len``. + + Args: + data_len: Size of the dataset + mini_epoch_len: Num samples from the dataset used in one + mini epoch. + drop_last: If ``True``, sampler will drop the last batches + if its size would be less than ``batches_per_epoch`` + shuffle: one of ``"always"``, ``"real_epoch"``, or `None``. + The sampler will shuffle indices + > "per_mini_epoch" - every mini epoch (every ``__iter__`` call) + > "per_epoch" -- every real epoch + > None -- don't shuffle + + Example: + >>> MiniEpochSampler(len(dataset), mini_epoch_len=100) + >>> MiniEpochSampler(len(dataset), mini_epoch_len=100, drop_last=True) + >>> MiniEpochSampler(len(dataset), mini_epoch_len=100, + >>> shuffle="per_epoch") + """ + + def __init__( + self, + data_len: int, + mini_epoch_len: int, + drop_last: bool = False, + shuffle: str = None, + ): + """Sampler initialisation.""" + super().__init__(None) + + self.data_len = int(data_len) + self.mini_epoch_len = int(mini_epoch_len) + + self.steps = int(data_len / self.mini_epoch_len) + self.state_i = 0 + + has_reminder = data_len - self.steps * mini_epoch_len > 0 + if self.steps == 0: + self.divider = 1 + elif has_reminder and not drop_last: + self.divider = self.steps + 1 + else: + self.divider = self.steps + + self._indices = np.arange(self.data_len) + self.indices = self._indices + self.end_pointer = max(self.data_len, self.mini_epoch_len) + + if not (shuffle is None or shuffle in ["per_mini_epoch", "per_epoch"]): + raise ValueError( + "Shuffle must be one of ['per_mini_epoch', 'per_epoch']. " + + f"Got {shuffle}" + ) + self.shuffle_type = shuffle + + def shuffle(self) -> None: + """Shuffle sampler indices.""" + if self.shuffle_type == "per_mini_epoch" or ( + self.shuffle_type == "per_epoch" and self.state_i == 0 + ): + if self.data_len >= self.mini_epoch_len: + self.indices = self._indices + np.random.shuffle(self.indices) + else: + self.indices = np.random.choice( + self._indices, self.mini_epoch_len, replace=True + ) + + def __iter__(self) -> Iterator[int]: + """Iterate over sampler. + + Returns: + python iterator + """ + self.state_i = self.state_i % self.divider + self.shuffle() + + start = self.state_i * self.mini_epoch_len + stop = ( + self.end_pointer + if (self.state_i == self.steps) + else (self.state_i + 1) * self.mini_epoch_len + ) + indices = self.indices[start:stop].tolist() + + self.state_i += 1 + return iter(indices) + + def __len__(self) -> int: + """ + Returns: + int: length of the mini-epoch + """ + return self.mini_epoch_len + + +class DistributedSamplerWrapper(DistributedSampler): + """ + Wrapper over `Sampler` for distributed training. + Allows you to use any sampler in distributed mode. + + It is especially useful in conjunction with + `torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSamplerWrapper instance as a DataLoader + sampler, and load a subset of subsampled data of the original dataset + that is exclusive to it. + + .. note:: + Sampler is assumed to be of constant size. + """ + + def __init__( + self, + sampler, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + ): + """ + + Args: + sampler: Sampler used for subsampling + num_replicas (int, optional): Number of processes participating in + distributed training + rank (int, optional): Rank of the current process + within ``num_replicas`` + shuffle (bool, optional): If true (default), + sampler will shuffle the indices + """ + super(DistributedSamplerWrapper, self).__init__( + DatasetFromSampler(sampler), + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) + self.sampler = sampler + + def __iter__(self) -> Iterator[int]: + """Iterate over sampler. + + Returns: + python iterator + """ + self.dataset = DatasetFromSampler(self.sampler) + indexes_of_indexes = super().__iter__() + subsampler_indexes = self.dataset + return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) + + +__all__ = [ + "BalanceClassSampler", + "BatchBalanceClassSampler", + "DistributedSamplerWrapper", + "DynamicBalanceClassSampler", + "MiniEpochSampler", +] diff --git a/audioldm2/utilities/tools.py b/audioldm2/utilities/tools.py new file mode 100755 index 0000000000000000000000000000000000000000..a647a272cdf076b2ae9785bc83724ebd7a897642 --- /dev/null +++ b/audioldm2/utilities/tools.py @@ -0,0 +1,541 @@ +# Author: Haohe Liu +# Email: haoheliu@gmail.com +# Date: 11 Feb 2023 + +import os +import json + +import torch +import torch.nn.functional as F +import numpy as np +import matplotlib +from scipy.io import wavfile +from matplotlib import pyplot as plt + + +matplotlib.use("Agg") + +import hashlib +import os + +import requests +from tqdm import tqdm + +URL_MAP = { + "vggishish_lpaps": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt", + "vggishish_mean_std_melspec_10s_22050hz": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt", + "melception": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt", +} + +CKPT_MAP = { + "vggishish_lpaps": "vggishish16.pt", + "vggishish_mean_std_melspec_10s_22050hz": "train_means_stds_melspec_10s_22050hz.txt", + "melception": "melception-21-05-10T09-28-40.pt", +} + +MD5_MAP = { + "vggishish_lpaps": "197040c524a07ccacf7715d7080a80bd", + "vggishish_mean_std_melspec_10s_22050hz": "f449c6fd0e248936c16f6d22492bb625", + "melception": "a71a41041e945b457c7d3d814bbcf72d", +} + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def load_json(fname): + with open(fname, "r") as f: + data = json.load(f) + return data + + +def read_json(dataset_json_file): + with open(dataset_json_file, "r") as fp: + data_json = json.load(fp) + return data_json["data"] + + +def copy_test_subset_data(metadata, testset_copy_target_path): + # metadata = read_json(testset_metadata) + os.makedirs(testset_copy_target_path, exist_ok=True) + if len(os.listdir(testset_copy_target_path)) == len(metadata): + return + else: + # delete files in folder testset_copy_target_path + for file in os.listdir(testset_copy_target_path): + try: + os.remove(os.path.join(testset_copy_target_path, file)) + except Exception as e: + print(e) + + print("Copying test subset data to {}".format(testset_copy_target_path)) + for each in tqdm(metadata): + cmd = "cp {} {}".format(each["wav"], os.path.join(testset_copy_target_path)) + os.system(cmd) + + +def listdir_nohidden(path): + for f in os.listdir(path): + if not f.startswith("."): + yield f + + +def get_restore_step(path): + checkpoints = os.listdir(path) + if os.path.exists(os.path.join(path, "final.ckpt")): + return "final.ckpt", 0 + elif not os.path.exists(os.path.join(path, "last.ckpt")): + steps = [int(x.split(".ckpt")[0].split("step=")[1]) for x in checkpoints] + return checkpoints[np.argmax(steps)], np.max(steps) + else: + steps = [] + for x in checkpoints: + if "last" in x: + if "-v" not in x: + fname = "last.ckpt" + else: + this_version = int(x.split(".ckpt")[0].split("-v")[1]) + steps.append(this_version) + if len(steps) == 0 or this_version > np.max(steps): + fname = "last-v%s.ckpt" % this_version + return fname, 0 + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class KeyNotFoundError(Exception): + def __init__(self, cause, keys=None, visited=None): + self.cause = cause + self.keys = keys + self.visited = visited + messages = list() + if keys is not None: + messages.append("Key not found: {}".format(keys)) + if visited is not None: + messages.append("Visited: {}".format(visited)) + messages.append("Cause:\n{}".format(cause)) + message = "\n".join(messages) + super().__init__(message) + + +def retrieve( + list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False +): + """Given a nested list or dict return the desired value at key expanding + callable nodes if necessary and :attr:`expand` is ``True``. The expansion + is done in-place. + + Parameters + ---------- + list_or_dict : list or dict + Possibly nested list or dictionary. + key : str + key/to/value, path like string describing all keys necessary to + consider to get to the desired value. List indices can also be + passed here. + splitval : str + String that defines the delimiter between keys of the + different depth levels in `key`. + default : obj + Value returned if :attr:`key` is not found. + expand : bool + Whether to expand callable nodes on the path or not. + + Returns + ------- + The desired value or if :attr:`default` is not ``None`` and the + :attr:`key` is not found returns ``default``. + + Raises + ------ + Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is + ``None``. + """ + + keys = key.split(splitval) + + success = True + try: + visited = [] + parent = None + last_key = None + for key in keys: + if callable(list_or_dict): + if not expand: + raise KeyNotFoundError( + ValueError( + "Trying to get past callable node with expand=False." + ), + keys=keys, + visited=visited, + ) + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + + last_key = key + parent = list_or_dict + + try: + if isinstance(list_or_dict, dict): + list_or_dict = list_or_dict[key] + else: + list_or_dict = list_or_dict[int(key)] + except (KeyError, IndexError, ValueError) as e: + raise KeyNotFoundError(e, keys=keys, visited=visited) + + visited += [key] + # final expansion of retrieved value + if expand and callable(list_or_dict): + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + except KeyNotFoundError as e: + if default is None: + raise e + else: + list_or_dict = default + success = False + + if not pass_success: + return list_or_dict + else: + return list_or_dict, success + + +def to_device(data, device): + if len(data) == 12: + ( + ids, + raw_texts, + speakers, + texts, + src_lens, + max_src_len, + mels, + mel_lens, + max_mel_len, + pitches, + energies, + durations, + ) = data + + speakers = torch.from_numpy(speakers).long().to(device) + texts = torch.from_numpy(texts).long().to(device) + src_lens = torch.from_numpy(src_lens).to(device) + mels = torch.from_numpy(mels).float().to(device) + mel_lens = torch.from_numpy(mel_lens).to(device) + pitches = torch.from_numpy(pitches).float().to(device) + energies = torch.from_numpy(energies).to(device) + durations = torch.from_numpy(durations).long().to(device) + + return ( + ids, + raw_texts, + speakers, + texts, + src_lens, + max_src_len, + mels, + mel_lens, + max_mel_len, + pitches, + energies, + durations, + ) + + if len(data) == 6: + (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data + + speakers = torch.from_numpy(speakers).long().to(device) + texts = torch.from_numpy(texts).long().to(device) + src_lens = torch.from_numpy(src_lens).to(device) + + return (ids, raw_texts, speakers, texts, src_lens, max_src_len) + + +def log(logger, step=None, fig=None, audio=None, sampling_rate=22050, tag=""): + # if losses is not None: + # logger.add_scalar("Loss/total_loss", losses[0], step) + # logger.add_scalar("Loss/mel_loss", losses[1], step) + # logger.add_scalar("Loss/mel_postnet_loss", losses[2], step) + # logger.add_scalar("Loss/pitch_loss", losses[3], step) + # logger.add_scalar("Loss/energy_loss", losses[4], step) + # logger.add_scalar("Loss/duration_loss", losses[5], step) + # if(len(losses) > 6): + # logger.add_scalar("Loss/disc_loss", losses[6], step) + # logger.add_scalar("Loss/fmap_loss", losses[7], step) + # logger.add_scalar("Loss/r_loss", losses[8], step) + # logger.add_scalar("Loss/g_loss", losses[9], step) + # logger.add_scalar("Loss/gen_loss", losses[10], step) + # logger.add_scalar("Loss/diff_loss", losses[11], step) + + if fig is not None: + logger.add_figure(tag, fig) + + if audio is not None: + audio = audio / (max(abs(audio)) * 1.1) + logger.add_audio( + tag, + audio, + sample_rate=sampling_rate, + ) + + +def get_mask_from_lengths(lengths, max_len=None): + batch_size = lengths.shape[0] + if max_len is None: + max_len = torch.max(lengths).item() + + ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device) + mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) + + return mask + + +def expand(values, durations): + out = list() + for value, d in zip(values, durations): + out += [value] * max(0, int(d)) + return np.array(out) + + +def synth_one_sample_val( + targets, predictions, vocoder, model_config, preprocess_config +): + index = np.random.choice(list(np.arange(targets[6].size(0)))) + + basename = targets[0][index] + src_len = predictions[8][index].item() + mel_len = predictions[9][index].item() + mel_target = targets[6][index, :mel_len].detach().transpose(0, 1) + + mel_prediction = predictions[0][index, :mel_len].detach().transpose(0, 1) + postnet_mel_prediction = predictions[1][index, :mel_len].detach().transpose(0, 1) + duration = targets[11][index, :src_len].detach().cpu().numpy() + + if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": + pitch = predictions[2][index, :src_len].detach().cpu().numpy() + pitch = expand(pitch, duration) + else: + pitch = predictions[2][index, :mel_len].detach().cpu().numpy() + + if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": + energy = predictions[3][index, :src_len].detach().cpu().numpy() + energy = expand(energy, duration) + else: + energy = predictions[3][index, :mel_len].detach().cpu().numpy() + + with open( + os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") + ) as f: + stats = json.load(f) + stats = stats["pitch"] + stats["energy"][:2] + + # from datetime import datetime + # now = datetime.now() + # current_time = now.strftime("%D:%H:%M:%S") + # np.save(("mel_pred_%s.npy" % current_time).replace("/","-"), mel_prediction.cpu().numpy()) + # np.save(("postnet_mel_prediction_%s.npy" % current_time).replace("/","-"), postnet_mel_prediction.cpu().numpy()) + # np.save(("mel_target_%s.npy" % current_time).replace("/","-"), mel_target.cpu().numpy()) + + fig = plot_mel( + [ + (mel_prediction.cpu().numpy(), pitch, energy), + (postnet_mel_prediction.cpu().numpy(), pitch, energy), + (mel_target.cpu().numpy(), pitch, energy), + ], + stats, + [ + "Raw mel spectrogram prediction", + "Postnet mel prediction", + "Ground-Truth Spectrogram", + ], + ) + + if vocoder is not None: + from .model import vocoder_infer + + wav_reconstruction = vocoder_infer( + mel_target.unsqueeze(0), + vocoder, + model_config, + preprocess_config, + )[0] + wav_prediction = vocoder_infer( + postnet_mel_prediction.unsqueeze(0), + vocoder, + model_config, + preprocess_config, + )[0] + else: + wav_reconstruction = wav_prediction = None + + return fig, wav_reconstruction, wav_prediction, basename + + +def synth_one_sample(mel_input, mel_prediction, labels, vocoder): + if vocoder is not None: + from .model import vocoder_infer + + wav_reconstruction = vocoder_infer( + mel_input.permute(0, 2, 1), + vocoder, + ) + wav_prediction = vocoder_infer( + mel_prediction.permute(0, 2, 1), + vocoder, + ) + else: + wav_reconstruction = wav_prediction = None + + return wav_reconstruction, wav_prediction + + +def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path): + # (diff_output, diff_loss, latent_loss) = diffusion + + basenames = targets[0] + + for i in range(len(predictions[1])): + basename = basenames[i] + src_len = predictions[8][i].item() + mel_len = predictions[9][i].item() + mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1) + # diff_output = diff_output[i, :mel_len].detach().transpose(0, 1) + # duration = predictions[5][i, :src_len].detach().cpu().numpy() + if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": + pitch = predictions[2][i, :src_len].detach().cpu().numpy() + # pitch = expand(pitch, duration) + else: + pitch = predictions[2][i, :mel_len].detach().cpu().numpy() + if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": + energy = predictions[3][i, :src_len].detach().cpu().numpy() + # energy = expand(energy, duration) + else: + energy = predictions[3][i, :mel_len].detach().cpu().numpy() + # import ipdb; ipdb.set_trace() + with open( + os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") + ) as f: + stats = json.load(f) + stats = stats["pitch"] + stats["energy"][:2] + + fig = plot_mel( + [ + (mel_prediction.cpu().numpy(), pitch, energy), + ], + stats, + ["Synthetized Spectrogram by PostNet"], + ) + # np.save("{}_postnet.npy".format(basename), mel_prediction.cpu().numpy()) + plt.savefig(os.path.join(path, "{}_postnet_2.png".format(basename))) + plt.close() + + from .model import vocoder_infer + + mel_predictions = predictions[1].transpose(1, 2) + lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"] + wav_predictions = vocoder_infer( + mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths + ) + + sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] + for wav, basename in zip(wav_predictions, basenames): + wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav) + + +def plot_mel(data, titles=None): + fig, axes = plt.subplots(len(data), 1, squeeze=False) + if titles is None: + titles = [None for i in range(len(data))] + + for i in range(len(data)): + mel = data[i] + axes[i][0].imshow(mel, origin="lower", aspect="auto") + axes[i][0].set_aspect(2.5, adjustable="box") + axes[i][0].set_ylim(0, mel.shape[0]) + axes[i][0].set_title(titles[i], fontsize="medium") + axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) + axes[i][0].set_anchor("W") + + return fig + + +def pad_1D(inputs, PAD=0): + def pad_data(x, length, PAD): + x_padded = np.pad( + x, (0, length - x.shape[0]), mode="constant", constant_values=PAD + ) + return x_padded + + max_len = max((len(x) for x in inputs)) + padded = np.stack([pad_data(x, max_len, PAD) for x in inputs]) + + return padded + + +def pad_2D(inputs, maxlen=None): + def pad(x, max_len): + PAD = 0 + if np.shape(x)[0] > max_len: + raise ValueError("not max_len") + + s = np.shape(x)[1] + x_padded = np.pad( + x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD + ) + return x_padded[:, :s] + + if maxlen: + output = np.stack([pad(x, maxlen) for x in inputs]) + else: + max_len = max(np.shape(x)[0] for x in inputs) + output = np.stack([pad(x, max_len) for x in inputs]) + + return output + + +def pad(input_ele, mel_max_length=None): + if mel_max_length: + max_len = mel_max_length + else: + max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) + + out_list = list() + for i, batch in enumerate(input_ele): + if len(batch.shape) == 1: + one_batch_padded = F.pad( + batch, (0, max_len - batch.size(0)), "constant", 0.0 + ) + elif len(batch.shape) == 2: + one_batch_padded = F.pad( + batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 + ) + out_list.append(one_batch_padded) + out_padded = torch.stack(out_list) + return out_padded diff --git a/audioldm2/utils.py b/audioldm2/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..c098e25b99cf78b0c0befc71fb7ba7688e79c899 --- /dev/null +++ b/audioldm2/utils.py @@ -0,0 +1,352 @@ +import contextlib +import importlib +from huggingface_hub import hf_hub_download + +from inspect import isfunction +import os +import soundfile as sf +import time +import wave + +import progressbar + +CACHE_DIR = os.getenv( + "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2") +) + +def read_list(fname): + result = [] + with open(fname, "r", encoding="utf-8") as f: + for each in f.readlines(): + each = each.strip('\n') + result.append(each) + return result + +def get_duration(fname): + with contextlib.closing(wave.open(fname, "r")) as f: + frames = f.getnframes() + rate = f.getframerate() + return frames / float(rate) + + +def get_bit_depth(fname): + with contextlib.closing(wave.open(fname, "r")) as f: + bit_depth = f.getsampwidth() * 8 + return bit_depth + + +def get_time(): + t = time.localtime() + return time.strftime("%d_%m_%Y_%H_%M_%S", t) + + +def seed_everything(seed): + import random, os + import numpy as np + import torch + + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + + +def save_wave(waveform, savepath, name="outwav"): + if type(name) is not list: + name = [name] * waveform.shape[0] + + for i in range(waveform.shape[0]): + path = os.path.join( + savepath, + "%s_%s.wav" + % ( + os.path.basename(name[i]) + if (not ".wav" in name[i]) + else os.path.basename(name[i]).split(".")[0], + i, + ), + ) + print("Save audio to %s" % path) + sf.write(path, waveform[i, 0], samplerate=16000) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + try: + return get_obj_from_str(config["target"])(**config.get("params", dict())) + except: + import ipdb + + ipdb.set_trace() + + +def default_audioldm_config(model_name="audioldm2-full"): + basic_config = { + "metadata_root": "/mnt/bn/lqhaoheliu/metadata/processed/dataset_root.json", + "log_directory": "./log/audiomae_pred", + "precision": "high", + "data": { + "train": [ + "audiocaps", + "audioset", + "wavcaps", + "audiostock_music_250k", + "free_to_use_sounds", + "epidemic_sound_effects", + "vggsound", + "million_song_dataset", + ], + "val": "audiocaps", + "test": "audiocaps", + "class_label_indices": "audioset", + "dataloader_add_ons": [ + "extract_kaldi_fbank_feature", + "extract_vits_phoneme_and_flant5_text", + "waveform_rs_48k", + ], + }, + "variables": { + "sampling_rate": 16000, + "mel_bins": 64, + "latent_embed_dim": 8, + "latent_t_size": 256, + "latent_f_size": 16, + "in_channels": 8, + "optimize_ddpm_parameter": True, + "warmup_steps": 5000, + }, + "step": { + "validation_every_n_epochs": 1, + "save_checkpoint_every_n_steps": 5000, + "limit_val_batches": 10, + "max_steps": 1500000, + "save_top_k": 2, + }, + "preprocessing": { + "audio": { + "sampling_rate": 16000, + "max_wav_value": 32768, + "duration": 10.24, + }, + "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024}, + "mel": {"n_mel_channels": 64, "mel_fmin": 0, "mel_fmax": 8000}, + }, + "augmentation": {"mixup": 0}, + "model": { + "target": "audioldm2.latent_diffusion.models.ddpm.LatentDiffusion", + "params": { + "first_stage_config": { + "base_learning_rate": 0.000008, + "target": "audioldm2.latent_encoder.autoencoder.AutoencoderKL", + "params": { + "sampling_rate": 16000, + "batchsize": 4, + "monitor": "val/rec_loss", + "image_key": "fbank", + "subband": 1, + "embed_dim": 8, + "time_shuffle": 1, + "lossconfig": { + "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator", + "params": { + "disc_start": 50001, + "kl_weight": 1000, + "disc_weight": 0.5, + "disc_in_channels": 1, + }, + }, + "ddconfig": { + "double_z": True, + "mel_bins": 64, + "z_channels": 8, + "resolution": 256, + "downsample_time": False, + "in_channels": 1, + "out_ch": 1, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0, + }, + }, + }, + "base_learning_rate": 0.0001, + "warmup_steps": 5000, + "optimize_ddpm_parameter": True, + "sampling_rate": 16000, + "batchsize": 16, + "linear_start": 0.0015, + "linear_end": 0.0195, + "num_timesteps_cond": 1, + "log_every_t": 200, + "timesteps": 1000, + "unconditional_prob_cfg": 0.1, + "parameterization": "eps", + "first_stage_key": "fbank", + "latent_t_size": 256, + "latent_f_size": 16, + "channels": 8, + "monitor": "val/loss_simple_ema", + "scale_by_std": True, + "unet_config": { + "target": "audioldm2.latent_diffusion.modules.diffusionmodules.openaimodel.UNetModel", + "params": { + "image_size": 64, + "context_dim": [768, 1024], + "in_channels": 8, + "out_channels": 8, + "model_channels": 128, + "attention_resolutions": [8, 4, 2], + "num_res_blocks": 2, + "channel_mult": [1, 2, 3, 5], + "num_head_channels": 32, + "use_spatial_transformer": True, + "transformer_depth": 1, + }, + }, + "evaluation_params": { + "unconditional_guidance_scale": 3.5, + "ddim_sampling_steps": 200, + "n_candidates_per_samples": 3, + }, + "cond_stage_config": { + "crossattn_audiomae_generated": { + "cond_stage_key": "all", + "conditioning_key": "crossattn", + "target": "audioldm2.latent_diffusion.modules.encoders.modules.SequenceGenAudioMAECond", + "params": { + "always_output_audiomae_gt": False, + "learnable": True, + "device": "cuda", + "use_gt_mae_output": True, + "use_gt_mae_prob": 0.25, + "base_learning_rate": 0.0002, + "sequence_gen_length": 8, + "use_warmup": True, + "sequence_input_key": [ + "film_clap_cond1", + "crossattn_flan_t5", + ], + "sequence_input_embed_dim": [512, 1024], + "batchsize": 16, + "cond_stage_config": { + "film_clap_cond1": { + "cond_stage_key": "text", + "conditioning_key": "film", + "target": "audioldm2.latent_diffusion.modules.encoders.modules.CLAPAudioEmbeddingClassifierFreev2", + "params": { + "sampling_rate": 48000, + "embed_mode": "text", + "amodel": "HTSAT-base", + }, + }, + "crossattn_flan_t5": { + "cond_stage_key": "text", + "conditioning_key": "crossattn", + "target": "audioldm2.latent_diffusion.modules.encoders.modules.FlanT5HiddenState", + }, + "crossattn_audiomae_pooled": { + "cond_stage_key": "ta_kaldi_fbank", + "conditioning_key": "crossattn", + "target": "audioldm2.latent_diffusion.modules.encoders.modules.AudioMAEConditionCTPoolRand", + "params": { + "regularization": False, + "no_audiomae_mask": True, + "time_pooling_factors": [8], + "freq_pooling_factors": [8], + "eval_time_pooling": 8, + "eval_freq_pooling": 8, + "mask_ratio": 0, + }, + }, + }, + }, + }, + "crossattn_flan_t5": { + "cond_stage_key": "text", + "conditioning_key": "crossattn", + "target": "audioldm2.latent_diffusion.modules.encoders.modules.FlanT5HiddenState", + }, + }, + }, + }, + } + return basic_config + + +def get_metadata(): + return { + "audioldm2-full": { + "path": os.path.join( + CACHE_DIR, + "audioldm2-full.pth", + ), + "url": "https://huggingface.co/haoheliu/audioldm2-full/resolve/main/audioldm2-full.pth", + }, + } + + +class MyProgressBar: + def __init__(self): + self.pbar = None + + def __call__(self, block_num, block_size, total_size): + if not self.pbar: + self.pbar = progressbar.ProgressBar(maxval=total_size) + self.pbar.start() + + downloaded = block_num * block_size + if downloaded < total_size: + self.pbar.update(downloaded) + else: + self.pbar.finish() + + +def download_checkpoint(checkpoint_name="audioldm2-full"): + meta = get_metadata() + if checkpoint_name not in meta.keys(): + print( + "The model name you provided is not supported. Please use one of the following: ", + meta.keys(), + ) + + model_id = "haoheliu/%s" % checkpoint_name + hf_hub_download( + repo_id=model_id, + filename=os.path.basename(meta[checkpoint_name]["path"]), + local_dir=os.path.dirname(meta[checkpoint_name]["path"]), + ) diff --git a/batch.lst b/batch.lst new file mode 100644 index 0000000000000000000000000000000000000000..c52c7a52523775ad729bfbb350f9cd70ddfbf3e4 --- /dev/null +++ b/batch.lst @@ -0,0 +1,4 @@ +A forest of wind chimes singing a soothing melody in the breeze. +A violin playing a heartfelt melody. +A saxophone playing a soulful melody. +Musical constellations twinkling in the night sky, forming a cosmic melody. \ No newline at end of file diff --git a/bg.png b/bg.png new file mode 100644 index 0000000000000000000000000000000000000000..2811a3593a7492c5af5754ab6949a6e60f2635bd Binary files /dev/null and b/bg.png differ diff --git a/bin/audioldm2 b/bin/audioldm2 new file mode 100755 index 0000000000000000000000000000000000000000..2ff95674ad63326a4b0d7f7b633c2f87949502f3 --- /dev/null +++ b/bin/audioldm2 @@ -0,0 +1,131 @@ +#!/usr/bin/python3 +import os +import torch +import logging +from audioldm2 import text_to_audio, build_model, save_wave, get_time, read_list +import argparse + +os.environ["TOKENIZERS_PARALLELISM"] = "true" +matplotlib_logger = logging.getLogger('matplotlib') +matplotlib_logger.setLevel(logging.WARNING) + + +CACHE_DIR = os.getenv( + "AUDIOLDM_CACHE_DIR", + os.path.join(os.path.expanduser("~"), ".cache/audioldm2")) + +parser = argparse.ArgumentParser() + +parser.add_argument( + "-t", + "--text", + type=str, + required=False, + default="", + help="Text prompt to the model for audio generation", +) + +parser.add_argument( + "-tl", + "--text_list", + type=str, + required=False, + default="", + help="A file that contains text prompt to the model for audio generation", +) + +parser.add_argument( + "-s", + "--save_path", + type=str, + required=False, + help="The path to save model output", + default="./output", +) + +parser.add_argument( + "--model_name", + type=str, + required=False, + help="The checkpoint you gonna use", + default="audioldm2-full", + choices=["audioldm2-full"] +) + +parser.add_argument( + "-b", + "--batchsize", + type=int, + required=False, + default=1, + help="Generate how many samples at the same time", +) + +parser.add_argument( + "--ddim_steps", + type=int, + required=False, + default=200, + help="The sampling step for DDIM", +) + +parser.add_argument( + "-gs", + "--guidance_scale", + type=float, + required=False, + default=3.5, + help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)", +) + +parser.add_argument( + "-n", + "--n_candidate_gen_per_text", + type=int, + required=False, + default=3, + help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation", +) + +parser.add_argument( + "--seed", + type=int, + required=False, + default=0, + help="Change this value (any integer number) will lead to a different generation result.", +) + +args = parser.parse_args() + +torch.set_float32_matmul_precision("high") + +save_path = os.path.join(args.save_path, get_time()) + +text = args.text +random_seed = args.seed +duration = 10 +guidance_scale = args.guidance_scale +n_candidate_gen_per_text = args.n_candidate_gen_per_text + +os.makedirs(save_path, exist_ok=True) +audioldm2 = build_model(model_name=args.model_name) + +if(args.text_list): + print("Generate audio based on the text prompts in %s" % args.text_list) + prompt_todo = read_list(args.text_list) +else: + prompt_todo = [text] + +for text in prompt_todo: + waveform = text_to_audio( + audioldm2, + text, + seed=random_seed, + duration=duration, + guidance_scale=guidance_scale, + ddim_steps=args.ddim_steps, + n_candidate_gen_per_text=n_candidate_gen_per_text, + batchsize=args.batchsize, + ) + + save_wave(waveform, save_path, name=text) diff --git a/bin/audioldm2.cmd b/bin/audioldm2.cmd new file mode 100755 index 0000000000000000000000000000000000000000..c164fbfb6a194858b6d9019c8e29df3e57b3172a --- /dev/null +++ b/bin/audioldm2.cmd @@ -0,0 +1,2 @@ +@echo OFF +python -m audioldm2 %* \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ea2a8af90d37d8a1994cc87a7731b03ea065cd8f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +git+https://github.com/huggingface/diffusers.git +git+https://github.com/huggingface/transformers.git +--extra-index-url https://download.pytorch.org/whl/cu113 +torch >= 2.0 +huggingface_hub +transformers==4.30.2 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..69f32a48bc29a8e293efbbf3fb05080e940de108 --- /dev/null +++ b/setup.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +# python3 setup.py sdist bdist_wheel +""" +@File : setup.py.py +@Contact : haoheliu@gmail.com +@License : (C)Copyright 2020-2100 + +@Modify Time @Author @Version @Desciption +------------ ------- -------- ----------- +9/6/21 5:16 PM Haohe Liu 1.0 None +""" + +# !/usr/bin/env python +# -*- coding: utf-8 -*- + +# Note: To use the 'upload' functionality of this file, you must: +# $ pipenv install twine --dev + +import io +import os +import sys +from shutil import rmtree + +from setuptools import find_packages, setup, Command + +# Package meta-data. +NAME = "audioldm2" +DESCRIPTION = "This package is written for text-to-audio/music generation." +URL = "https://github.com/haoheliu/audioldm2" +EMAIL = "haoheliu@gmail.com" +AUTHOR = "Haohe Liu" +REQUIRES_PYTHON = ">=3.7.0" +VERSION = "0.0.2" + +# What packages are required for this module to be executed? +REQUIRED = [ + "torch>=1.13.0", + "torchaudio>=0.13.0", + "torchvision>=0.14.0", + "tqdm", + "gradio", + "pyyaml", + "einops", + "chardet", + "numpy<=1.23.5", + "soundfile", + "librosa==0.9.2", + "scipy", + "pandas", + "torchlibrosa==0.0.9", + "transformers", + "progressbar", + "ftfy", +] + +# What packages are optional? +EXTRAS = {} + +# The rest you shouldn't have to touch too much :) +# ------------------------------------------------ +# Except, perhaps the License and Trove Classifiers! +# If you do change the License, remember to change the Trove Classifier for that! + +here = os.path.abspath(os.path.dirname(__file__)) + +# Import the README and use it as the long-description. +# Note: this will only work if 'README.md' is present in your MANIFEST.in file! +try: + with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: + long_description = "\n" + f.read() +except FileNotFoundError: + long_description = DESCRIPTION + +# Load the package's __version__.py module as a dictionary. +about = {} +if not VERSION: + project_slug = NAME.lower().replace("-", "_").replace(" ", "_") + with open(os.path.join(here, project_slug, "__version__.py")) as f: + exec(f.read(), about) +else: + about["__version__"] = VERSION + + +class UploadCommand(Command): + """Support setup.py upload.""" + + description = "Build and publish the package." + user_options = [] + + @staticmethod + def status(s): + """Prints things in bold.""" + print("\033[1m{0}\033[0m".format(s)) + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + try: + self.status("Removing previous builds…") + rmtree(os.path.join(here, "dist")) + except OSError: + pass + + self.status("Building Source and Wheel (universal) distribution…") + os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) + + self.status("Uploading the package to PyPI via Twine…") + os.system("twine upload dist/*") + + self.status("Pushing git tags…") + os.system("git tag v{0}".format(about["__version__"])) + os.system("git push --tags") + + sys.exit() + + +# Where the magic happens: +setup( + name=NAME, + version=about["__version__"], + description=DESCRIPTION, + long_description=long_description, + long_description_content_type="text/markdown", + author=AUTHOR, + author_email=EMAIL, + python_requires=REQUIRES_PYTHON, + url=URL, + # packages=find_packages(exclude=[]), + # If your package is a single module, use this instead of 'packages': + # entry_points={ + # 'console_scripts': ['mycli=mymodule:cli'], + # }, + install_requires=REQUIRED, + extras_require=EXTRAS, + packages=find_packages(), + include_package_data=True, + license="MIT", + classifiers=[ + # Trove classifiers + # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers + "License :: OSI Approved :: MIT License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + ], + # $ setup.py publish support. + cmdclass={ + "upload": UploadCommand, + }, + scripts=["bin/audioldm2.cmd", "bin/audioldm2"], +) diff --git a/share_btn.py b/share_btn.py new file mode 100644 index 0000000000000000000000000000000000000000..a0378607680fa5468e9034d230f546f5f0913ae0 --- /dev/null +++ b/share_btn.py @@ -0,0 +1,74 @@ +community_icon_html = """""" + +loading_icon_html = """""" + +share_js = """async () => { + async function uploadFile(file){ + const UPLOAD_URL = 'https://huggingface.co/uploads'; + const response = await fetch(UPLOAD_URL, { + method: 'POST', + headers: { + 'Content-Type': file.type, + 'X-Requested-With': 'XMLHttpRequest', + }, + body: file, /// <- File inherits from Blob + }); + const url = await response.text(); + return url; + } + async function getInputVideoFile(videoEl){ + const res = await fetch(videoEl.src); + const blob = await res.blob(); + const videoId = Date.now() % 200; + const fileName = `sd-perception-${{videoId}}.mp4`; + return new File([blob], fileName, { type: 'video/mp4' }); + } + + async function audioToBase64(audioFile) { + return new Promise((resolve, reject) => { + let reader = new FileReader(); + reader.readAsDataURL(audioFile); + reader.onload = () => resolve(reader.result); + reader.onerror = error => reject(error); + + }); + } + const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app'); + const inputPromptEl = gradioEl.querySelector('#prompt-in input').value; + const outputVideoEl = gradioEl.querySelector('#output-video video'); + + let titleTxt = `Text-to-Audio: ${inputPromptEl}`; + + const shareBtnEl = gradioEl.querySelector('#share-btn'); + const shareIconEl = gradioEl.querySelector('#share-btn-share-icon'); + const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon'); + if(!outputVideoEl){ + return; + }; + shareBtnEl.style.pointerEvents = 'none'; + shareIconEl.style.display = 'none'; + loadingIconEl.style.removeProperty('display'); + const outputVideo = await getInputVideoFile(outputVideoEl); + const urlOutputVideo = await uploadFile(outputVideo); + + const descriptionMd = ` +##### ${inputPromptEl} + +${urlOutputVideo} +`; + const params = new URLSearchParams({ + title: titleTxt, + description: descriptionMd, + }); + const paramsStr = params.toString(); + window.open(`https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/discussions/new?${paramsStr}`, '_blank'); + shareBtnEl.style.removeProperty('pointer-events'); + shareIconEl.style.removeProperty('display'); + loadingIconEl.style.display = 'none'; +}""" diff --git a/tests/code_coverage.py b/tests/code_coverage.py new file mode 100644 index 0000000000000000000000000000000000000000..deb035e9fedffacd8bf3a9c37d5566fa8fd4e819 --- /dev/null +++ b/tests/code_coverage.py @@ -0,0 +1,3 @@ +import os + +os.system('python3 bin/audioldm2 -t "A toilet flushing and water trickling"') diff --git a/tests/code_coverage.sh b/tests/code_coverage.sh new file mode 100644 index 0000000000000000000000000000000000000000..0a5920c645c262c80436ff586e7ad4825e9e5622 --- /dev/null +++ b/tests/code_coverage.sh @@ -0,0 +1 @@ +pytest --cov=src tests/* \ No newline at end of file