diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..24aa54d81a0af0eeffa744a6c81e8a79f5cb76a3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+*.pyc
+__pycache__
+test.py
+flagged
+output
+gradio_cached*
+dist*
+*egg-info
+build*
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d56ccfbb7496de8592c9745d0d9a5e390af75fd6
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,437 @@
+Attribution-NonCommercial-ShareAlike 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.cp
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More_considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
+Public License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial-ShareAlike 4.0 International Public License
+("Public License"). To the extent this Public License may be
+interpreted as a contract, You are granted the Licensed Rights in
+consideration of Your acceptance of these terms and conditions, and the
+Licensor grants You such rights in consideration of benefits the
+Licensor receives from making the Licensed Material available under
+these terms and conditions.
+
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. BY-NC-SA Compatible License means a license listed at
+ creativecommons.org/compatiblelicenses, approved by Creative
+ Commons as essentially the equivalent of this Public License.
+
+ d. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+
+ e. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ g. License Elements means the license attributes listed in the name
+ of a Creative Commons Public License. The License Elements of this
+ Public License are Attribution, NonCommercial, and ShareAlike.
+
+ h. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ i. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ j. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ k. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ l. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ m. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ n. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. Additional offer from the Licensor -- Adapted Material.
+ Every recipient of Adapted Material from You
+ automatically receives an offer from the Licensor to
+ exercise the Licensed Rights in the Adapted Material
+ under the conditions of the Adapter's License You apply.
+
+ c. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ b. ShareAlike.
+
+ In addition to the conditions in Section 3(a), if You Share
+ Adapted Material You produce, the following conditions also apply.
+
+ 1. The Adapter's License You apply must be a Creative Commons
+ license with the same License Elements, this version or
+ later, or a BY-NC-SA Compatible License.
+
+ 2. You must include the text of, or the URI or hyperlink to, the
+ Adapter's License You apply. You may satisfy this condition
+ in any reasonable manner based on the medium, means, and
+ context in which You Share Adapted Material.
+
+ 3. You may not offer or impose any additional or different terms
+ or conditions on, or apply any Effective Technological
+ Measures to, Adapted Material that restrict exercise of the
+ rights granted under the Adapter's License You apply.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material,
+ including for purposes of Section 3(b); and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..5dcbc139b80a520e73d210bb3236cb0a25a31129
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,2 @@
+include *.py LICENSE README.md
+recursive-include audioldm2 *.txt *.py *.gz *.npy *.json
\ No newline at end of file
diff --git a/README.md b/README.md
index d8cd2b7e733713d278e699014b616f91add1cfaa..4d55fdfc128beb20be6f2512363a1c8663bff795 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,23 @@
---
-title: Audioldm2 Text2audio Text2music
-emoji: 👁
-colorFrom: gray
-colorTo: green
+title: AudioLDM2 Text2Audio Text2Music Generation
+emoji: 🔊
+colorFrom: indigo
+colorTo: red
sdk: gradio
-sdk_version: 3.39.0
+sdk_version: 3.27.0
app_file: app.py
pinned: false
-license: cc-by-nc-nd-4.0
+license: bigscience-openrail-m
+duplicated_from: haoheliu/audioldm2-text2audio-text2music
+
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+## Reference
+Part of the code from this repo is borrowed from the following repos. We would like to thank the authors of them for their contribution.
+
+> https://github.com/LAION-AI/CLAP
+> https://github.com/CompVis/stable-diffusion
+> https://github.com/v-iashin/SpecVQGAN
+> https://github.com/toshas/torch-fidelity
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2bba800105bda2f7fd171a68f28c31b00a87bd9
--- /dev/null
+++ b/app.py
@@ -0,0 +1,361 @@
+from huggingface_hub import hf_hub_download
+import torch
+import os
+
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+import gradio as gr
+from audioldm2 import text_to_audio, build_model
+from share_btn import community_icon_html, loading_icon_html, share_js
+
+model_id = "haoheliu/audioldm2-full"
+hf_hub_download(repo_id="haoheliu/audioldm2-full", filename="audioldm2-full.pth")
+
+audioldm = None
+current_model_name = None
+
+def text2audio(
+ text,
+ guidance_scale,
+ random_seed,
+ n_candidates,
+ model_name="audioldm2-full",
+):
+ global audioldm, current_model_name
+ torch.set_float32_matmul_precision("high")
+
+ if audioldm is None or model_name != current_model_name:
+ audioldm = build_model(model_name=model_name)
+ current_model_name = model_name
+ audioldm = torch.compile(audioldm)
+
+ # print(text, length, guidance_scale)
+ waveform = text_to_audio(
+ latent_diffusion=audioldm,
+ text=text,
+ seed=random_seed,
+ duration=10,
+ guidance_scale=guidance_scale,
+ n_candidate_gen_per_text=int(n_candidates),
+ ) # [bs, 1, samples]
+ waveform = [
+ gr.make_waveform((16000, wave[0]), bg_image="bg.png") for wave in waveform
+ ]
+ # waveform = [(16000, np.random.randn(16000)), (16000, np.random.randn(16000))]
+ if len(waveform) == 1:
+ waveform = waveform[0]
+ return waveform
+
+css = """
+ a {
+ color: inherit;
+ text-decoration: underline;
+ }
+ .gradio-container {
+ font-family: 'IBM Plex Sans', sans-serif;
+ }
+ .gr-button {
+ color: white;
+ border-color: #000000;
+ background: #000000;
+ }
+ input[type='range'] {
+ accent-color: #000000;
+ }
+ .dark input[type='range'] {
+ accent-color: #dfdfdf;
+ }
+ .container {
+ max-width: 730px;
+ margin: auto;
+ padding-top: 1.5rem;
+ }
+ #gallery {
+ min-height: 22rem;
+ margin-bottom: 15px;
+ margin-left: auto;
+ margin-right: auto;
+ border-bottom-right-radius: .5rem !important;
+ border-bottom-left-radius: .5rem !important;
+ }
+ #gallery>div>.h-full {
+ min-height: 20rem;
+ }
+ .details:hover {
+ text-decoration: underline;
+ }
+ .gr-button {
+ white-space: nowrap;
+ }
+ .gr-button:focus {
+ border-color: rgb(147 197 253 / var(--tw-border-opacity));
+ outline: none;
+ box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
+ --tw-border-opacity: 1;
+ --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
+ --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
+ --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
+ --tw-ring-opacity: .5;
+ }
+ #advanced-btn {
+ font-size: .7rem !important;
+ line-height: 19px;
+ margin-top: 12px;
+ margin-bottom: 12px;
+ padding: 2px 8px;
+ border-radius: 14px !important;
+ }
+ #advanced-options {
+ margin-bottom: 20px;
+ }
+ .footer {
+ margin-bottom: 45px;
+ margin-top: 35px;
+ text-align: center;
+ border-bottom: 1px solid #e5e5e5;
+ }
+ .footer>p {
+ font-size: .8rem;
+ display: inline-block;
+ padding: 0 10px;
+ transform: translateY(10px);
+ background: white;
+ }
+ .dark .footer {
+ border-color: #303030;
+ }
+ .dark .footer>p {
+ background: #0b0f19;
+ }
+ .acknowledgments h4{
+ margin: 1.25em 0 .25em 0;
+ font-weight: bold;
+ font-size: 115%;
+ }
+ #container-advanced-btns{
+ display: flex;
+ flex-wrap: wrap;
+ justify-content: space-between;
+ align-items: center;
+ }
+ .animate-spin {
+ animation: spin 1s linear infinite;
+ }
+ @keyframes spin {
+ from {
+ transform: rotate(0deg);
+ }
+ to {
+ transform: rotate(360deg);
+ }
+ }
+ #share-btn-container {
+ display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
+ margin-top: 10px;
+ margin-left: auto;
+ }
+ #share-btn {
+ all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
+ }
+ #share-btn * {
+ all: unset;
+ }
+ #share-btn-container div:nth-child(-n+2){
+ width: auto !important;
+ min-height: 0px !important;
+ }
+ #share-btn-container .wrap {
+ display: none !important;
+ }
+ .gr-form{
+ flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
+ }
+ #prompt-container{
+ gap: 0;
+ }
+ #generated_id{
+ min-height: 700px
+ }
+ #setting_id{
+ margin-bottom: 12px;
+ text-align: center;
+ font-weight: 900;
+ }
+"""
+iface = gr.Blocks(css=css)
+
+with iface:
+ gr.HTML(
+ """
+
+
+
+ AudioLDM 2: A General Framework for Audio, Music, and Speech Generation
+
+
+
+ [Paper] [Project page]
+
+
+ """
+ )
+ gr.HTML(
+ """
+
+ AudioLDM 2: A General Framework for Audio, Music, and Speech Generation
+
+ For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
+
+
+
+
+ """
+ )
+ with gr.Group():
+ with gr.Box():
+ ############# Input
+ textbox = gr.Textbox(
+ value="A forest of wind chimes singing a soothing melody in the breeze.",
+ max_lines=1,
+ label="Input your text here. Your text is important for the audio quality. Please ensure it is descriptive by using more adjectives.",
+ elem_id="prompt-in",
+ )
+
+ with gr.Accordion("Click to modify detailed configurations", open=False):
+ seed = gr.Number(
+ value=45,
+ label="Change this value (any integer number) will lead to a different generation result.",
+ )
+ # duration = gr.Slider(
+ # 10, 10, value=10, step=2.5, label="Duration (seconds)"
+ # )
+ guidance_scale = gr.Slider(
+ 0,
+ 6,
+ value=3.5,
+ step=0.5,
+ label="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
+ )
+ n_candidates = gr.Slider(
+ 1,
+ 3,
+ value=3,
+ step=1,
+ label="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
+ )
+ # model_name = gr.Dropdown(
+ # ["audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full","audioldm-s-full-v2", "audioldm-s-full", "audioldm-l-full"], value="audioldm-m-full", label="Choose the model to use. audioldm-m-text-ft and audioldm-s-text-ft are recommanded. -s- means small, -m- means medium and -l- means large",
+ # )
+ ############# Output
+ # outputs=gr.Audio(label="Output", type="numpy")
+ outputs = gr.Video(label="Output", elem_id="output-video")
+
+ # with gr.Group(elem_id="container-advanced-btns"):
+ # # advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
+ # with gr.Group(elem_id="share-btn-container"):
+ # community_icon = gr.HTML(community_icon_html, visible=False)
+ # loading_icon = gr.HTML(loading_icon_html, visible=False)
+ # share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
+ # outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")]
+ btn = gr.Button("Submit").style(full_width=True)
+
+ with gr.Group(elem_id="share-btn-container", visible=False):
+ community_icon = gr.HTML(community_icon_html)
+ loading_icon = gr.HTML(loading_icon_html)
+ share_button = gr.Button("Share to community", elem_id="share-btn")
+
+ # btn.click(text2audio, inputs=[
+ # textbox, duration, guidance_scale, seed, n_candidates, model_name], outputs=[outputs])
+ btn.click(
+ text2audio,
+ inputs=[textbox, guidance_scale, seed, n_candidates],
+ outputs=[outputs],
+ )
+
+ share_button.click(None, [], [], _js=share_js)
+ gr.HTML(
+ """
+
+ """
+ )
+ gr.Examples(
+ [
+ [
+ "An excited crowd cheering at a sports game.",
+ 3.5,
+ 45,
+ 3,
+ "audioldm2-full",
+ ],
+ [
+ "A cat is meowing for attention.",
+ 3.5,
+ 45,
+ 3,
+ "audioldm2-full",
+ ],
+ [
+ "Birds singing sweetly in a blooming garden.",
+ 3.5,
+ 45,
+ 3,
+ "audioldm2-full",
+ ],
+ [
+ "A modern synthesizer creating futuristic soundscapes.",
+ 3.5,
+ 45,
+ 3,
+ "audioldm2-full",
+ ],
+ [
+ "The vibrant beat of Brazilian samba drums.",
+ 3.5,
+ 45,
+ 3,
+ "audioldm2-full",
+ ],
+ ],
+ fn=text2audio,
+ # inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name],
+ inputs=[textbox, guidance_scale, seed, n_candidates],
+ outputs=[outputs],
+ cache_examples=True,
+ )
+ gr.HTML(
+ """
+
+
Essential Tricks for Enhancing the Quality of Your Generated Audio
+
1. Try to use more adjectives to describe your sound. For example: "A man is speaking clearly and slowly in a large room" is better than "A man is speaking". This can make sure AudioLDM understands what you want.
+
2. Try to use different random seeds, which can affect the generation quality significantly sometimes.
+
3. It's better to use general terms like 'man' or 'woman' instead of specific names for individuals or abstract objects that humans may not be familiar with, such as 'mummy'.
+
+ """
+ )
+
+ with gr.Accordion("Additional information", open=False):
+ gr.HTML(
+ """
+
+ """
+ )
+# This demo is strictly for research demo purpose only. For commercial use please contact us .
+
+iface.queue(concurrency_count=3)
+# iface.launch(debug=True)
+iface.launch(debug=True, share=True)
diff --git a/audioldm2/__init__.py b/audioldm2/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..91befda907125b4772601b1df2c9a8a52b733735
--- /dev/null
+++ b/audioldm2/__init__.py
@@ -0,0 +1,2 @@
+from .utils import seed_everything, save_wave, get_time, get_duration, read_list
+from .pipeline import *
diff --git a/audioldm2/audiomae_gen/__init__.py b/audioldm2/audiomae_gen/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..7202889ac7aeb7e5b344da994206715cbbb3891e
--- /dev/null
+++ b/audioldm2/audiomae_gen/__init__.py
@@ -0,0 +1 @@
+from .sequence_input import Sequence2AudioMAE
diff --git a/audioldm2/audiomae_gen/sequence_input.py b/audioldm2/audiomae_gen/sequence_input.py
new file mode 100755
index 0000000000000000000000000000000000000000..4d961a0dd7157689fab6291bb3c40d9bd656b5f1
--- /dev/null
+++ b/audioldm2/audiomae_gen/sequence_input.py
@@ -0,0 +1,429 @@
+import torch
+import torch.nn as nn
+from audioldm2.latent_diffusion.util import (
+ instantiate_from_config,
+)
+
+# from latent_diffusion.modules.encoders.modules import CLAPAudioEmbeddingClassifierFreev2
+from transformers import GPT2Config, GPT2Model
+import torch.optim.lr_scheduler as lr_scheduler
+
+class Sequence2AudioMAE(nn.Module):
+ def __init__(
+ self,
+ base_learning_rate,
+ sequence_gen_length,
+ sequence_input_key,
+ sequence_input_embed_dim,
+ cond_stage_config,
+ optimizer_type="AdamW",
+ use_warmup=True,
+ use_ar_gen_loss=False,
+ use_audiomae_linear=False,
+ target_tokens_mask_ratio=0.0,
+ random_mask_ratio=False,
+ **kwargs
+ ):
+ super().__init__()
+ assert use_audiomae_linear == False
+ self.random_mask_ratio = random_mask_ratio
+ self.learning_rate = base_learning_rate
+ self.cond_stage_config = cond_stage_config
+ self.use_audiomae_linear = use_audiomae_linear
+ self.optimizer_type = optimizer_type
+ self.use_warmup = use_warmup
+ self.use_ar_gen_loss = use_ar_gen_loss
+ # Even though the LDM can be conditioned on mutliple pooling rate
+ # Our model always predict the higest pooling rate
+
+ # self.time_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["time_pooling_factors"])
+ # self.freq_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["freq_pooling_factors"])
+ # self.mae_token_num = int(512/(self.time_pool*self.freq_pool))
+
+ self.mae_token_num = sequence_gen_length
+ self.sequence_input_key = sequence_input_key
+ self.sequence_input_embed_dim = sequence_input_embed_dim
+ self.target_tokens_mask_ratio = target_tokens_mask_ratio
+
+ self.start_of_sequence_tokens = nn.Embedding(32, 768)
+ self.end_of_sequence_tokens = nn.Embedding(32, 768)
+
+ self.input_sequence_embed_linear = nn.ModuleList([])
+ self.initial_learning_rate = None
+
+ for dim in self.sequence_input_embed_dim:
+ self.input_sequence_embed_linear.append(nn.Linear(dim, 768))
+
+ self.cond_stage_models = nn.ModuleList([])
+ self.instantiate_cond_stage(cond_stage_config)
+ self.initialize_param_check_toolkit()
+
+ # configuration = GPT2Config(n_layer=1) # TODO
+ # self.model=GPT2Model(configuration)
+ ###################
+ # self.model=nn.Linear(768,768, bias=False) # TODO change the model
+ # with torch.no_grad():
+ # self.model.weight.copy_(torch.eye(768))
+ ###################
+ self.model = GPT2Model(GPT2Config.from_pretrained("gpt2"))
+ ###################
+ # self.model = nn.LSTM(input_size=768, hidden_size=768, num_layers=1,bias=False) # TODO
+
+ # self.loss_fn = nn.MSELoss()
+ self.loss_fn = nn.L1Loss()
+
+ self.logger_save_dir = None
+ self.logger_exp_name = None
+ self.logger_exp_group_name = None
+ self.logger_version = None
+
+ def set_log_dir(self, save_dir, exp_group_name, exp_name):
+ self.logger_save_dir = save_dir
+ self.logger_exp_group_name = exp_group_name
+ self.logger_exp_name = exp_name
+
+ def cfg_uncond(self, batch_size):
+ unconditional_conditioning = {}
+ for key in self.cond_stage_model_metadata:
+ model_idx = self.cond_stage_model_metadata[key]["model_idx"]
+ unconditional_conditioning[key] = self.cond_stage_models[
+ model_idx
+ ].get_unconditional_condition(batch_size)
+ assert (
+ "crossattn_audiomae_pooled" in unconditional_conditioning.keys()
+ ), "The module is not initialized with AudioMAE"
+ unconditional_conditioning[
+ "crossattn_clap_to_audiomae_feature"
+ ] = unconditional_conditioning["crossattn_audiomae_pooled"]
+ return unconditional_conditioning
+
+ def configure_optimizers(self):
+ lr = float(self.learning_rate)
+ # params = list(self.model.parameters()) + list(self.input_sequence_embed_linear.parameters())
+ params = list(self.parameters())
+
+ # opt = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.98), eps=1e-9)
+ opt = eval(self.optimizer_type)(params, lr=lr)
+ scheduler = lr_scheduler.StepLR(opt, step_size=10, gamma=0.8)
+ return [opt], [scheduler]
+
+ def add_sos_eos_tokens(self, _id, sequence, attn_mask):
+ batchsize = sequence.size(0)
+
+ new_attn_mask_step = torch.ones((batchsize, 1)).to(sequence.device)
+ key_id = torch.tensor([_id]).to(sequence.device)
+
+ # Add two more steps to attn mask
+ new_attn_mask = torch.cat(
+ [new_attn_mask_step, attn_mask, new_attn_mask_step], dim=1
+ )
+
+ # Add two more tokens in the sequence
+ sos_token = self.start_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
+ eos_token = self.end_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
+ new_sequence = torch.cat([sos_token, sequence, eos_token], dim=1)
+ return new_sequence, new_attn_mask
+
+ def truncate_sequence_and_mask(self, sequence, mask, max_len=512):
+ if sequence.size(1) > max_len:
+ print(
+ "The input sequence length to GPT-2 model is too long:",
+ sequence.size(1),
+ )
+ return sequence[:, :max_len], mask[:, :max_len]
+ else:
+ return sequence, mask
+
+ def get_input_sequence_and_mask(self, cond_dict):
+ input_embeds = None
+ input_embeds_attn_mask = None
+ for _id, sequence_key in enumerate(self.sequence_input_key):
+ assert sequence_key in cond_dict.keys(), (
+ "Invalid sequence key %s" % sequence_key
+ )
+ cond_embed = cond_dict[sequence_key]
+ if isinstance(cond_embed, list):
+ assert (
+ len(cond_embed) == 2
+ ), "The crossattn returned list should have length 2, including embed and attn_mask"
+ item_input_embeds, item_attn_mask = cond_embed
+
+ item_input_embeds = self.input_sequence_embed_linear[_id](
+ item_input_embeds
+ )
+
+ item_input_embeds, item_attn_mask = self.add_sos_eos_tokens(
+ _id, item_input_embeds, item_attn_mask
+ )
+
+ if input_embeds is None and input_embeds_attn_mask is None:
+ input_embeds, input_embeds_attn_mask = (
+ item_input_embeds,
+ item_attn_mask,
+ )
+ else:
+ input_embeds = torch.cat(
+ [input_embeds, item_input_embeds], dim=1
+ ) # The 1-st dimension is time steps
+ input_embeds_attn_mask = torch.cat(
+ [input_embeds_attn_mask, item_attn_mask], dim=1
+ ) # The 1-st dimension is time steps
+ else:
+ assert isinstance(cond_embed, torch.Tensor)
+ cond_embed = self.input_sequence_embed_linear[_id](cond_embed)
+ attn_mask = torch.ones((cond_embed.size(0), cond_embed.size(1))).to(
+ cond_embed.device
+ )
+
+ item_input_embeds, item_attn_mask = self.add_sos_eos_tokens(
+ _id, cond_embed, attn_mask
+ )
+
+ if input_embeds is None and input_embeds_attn_mask is None:
+ input_embeds, input_embeds_attn_mask = (
+ item_input_embeds,
+ item_attn_mask,
+ )
+ else:
+ input_embeds, input_embeds_attn_mask = torch.cat(
+ [input_embeds, item_input_embeds], dim=1
+ ), torch.cat([input_embeds_attn_mask, item_attn_mask], dim=1)
+
+ assert input_embeds is not None and input_embeds_attn_mask is not None
+
+ input_embeds, input_embeds_attn_mask = self.truncate_sequence_and_mask(
+ input_embeds, input_embeds_attn_mask, int(1024 - self.mae_token_num)
+ )
+ cond_sequence_end_time_idx = input_embeds.size(
+ 1
+ ) # The index that we start to collect the output embeds
+
+ return input_embeds, input_embeds_attn_mask, cond_sequence_end_time_idx
+
+ def warmup_step(self):
+ if self.initial_learning_rate is None:
+ self.initial_learning_rate = float(self.learning_rate)
+
+ # Only the first parameter group
+ if self.global_step <= 1000:
+ if self.global_step == 0:
+ print(
+ "Warming up learning rate start with %s"
+ % self.initial_learning_rate
+ )
+ self.trainer.optimizers[0].param_groups[0]["lr"] = (
+ self.global_step / 1000
+ ) * self.initial_learning_rate
+ else:
+ # TODO set learning rate here
+ self.trainer.optimizers[0].param_groups[0][
+ "lr"
+ ] = self.initial_learning_rate
+
+ def mask_target_sequence(self, target_embeds, target_embeds_attn_mask):
+ time_seq_mask = None
+ if self.target_tokens_mask_ratio > 1e-4:
+ batchsize, time_seq_len, embed_dim = target_embeds.size()
+ _, time_seq_len = target_embeds_attn_mask.size()
+ # Generate random mask
+ if self.random_mask_ratio:
+ mask_ratio = torch.rand(1).item() * self.target_tokens_mask_ratio
+ else:
+ mask_ratio = self.target_tokens_mask_ratio
+
+ time_seq_mask = (torch.rand((batchsize, time_seq_len)) > mask_ratio).to(
+ target_embeds.device
+ )
+ # Mask the target embedding
+ target_embeds = target_embeds * time_seq_mask.unsqueeze(-1)
+ target_embeds_attn_mask = target_embeds_attn_mask * time_seq_mask
+ return target_embeds, target_embeds_attn_mask, time_seq_mask
+
+ def generate_partial(self, batch, cond_dict=None, no_grad=False):
+ if cond_dict is None:
+ cond_dict = self.get_input(batch)
+
+ print("Generate partially prompted audio with in-context learning")
+ # self.model.train()
+ # assert self.model.training==True
+
+ target_embeds, target_embeds_attn_mask = (
+ cond_dict["crossattn_audiomae_pooled"][0],
+ cond_dict["crossattn_audiomae_pooled"][1],
+ )
+
+ target_time_steps = target_embeds.size(1)
+
+ (
+ input_embeds,
+ input_embeds_attn_mask,
+ cond_sequence_end_time_idx,
+ ) = self.get_input_sequence_and_mask(cond_dict)
+
+ model_input = torch.cat(
+ [input_embeds, target_embeds[:, : target_time_steps // 4, :]], dim=1
+ )
+ model_input_mask = torch.cat(
+ [
+ input_embeds_attn_mask,
+ target_embeds_attn_mask[:, : target_time_steps // 4],
+ ],
+ dim=1,
+ )
+
+ steps = self.mae_token_num
+
+ for _ in range(3 * steps // 4):
+ output = self.model(
+ inputs_embeds=model_input, attention_mask=model_input_mask
+ )["last_hidden_state"]
+ # Update the model input
+ model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
+ # Update the attention mask
+ attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to(
+ model_input.device
+ )
+ model_input_mask = torch.cat(
+ [model_input_mask, attention_mask_new_step], dim=1
+ )
+
+ output = model_input[:, cond_sequence_end_time_idx:]
+
+ return output, cond_dict
+
+ def generate(self, batch, cond_dict=None, no_grad=False):
+ if cond_dict is None:
+ cond_dict = self.get_input(batch)
+
+ # self.model.train()
+ # print("!!!!!!!!!!!!!train")
+
+ (
+ input_embeds,
+ input_embeds_attn_mask,
+ cond_sequence_end_time_idx,
+ ) = self.get_input_sequence_and_mask(cond_dict)
+ model_input = input_embeds
+ model_input_mask = input_embeds_attn_mask
+
+ steps = self.mae_token_num
+
+ for _ in range(steps):
+ output = self.model(
+ inputs_embeds=model_input, attention_mask=model_input_mask
+ )["last_hidden_state"]
+ # Update the model input
+ model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
+ # Update the attention mask
+ attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to(
+ model_input.device
+ )
+ model_input_mask = torch.cat(
+ [model_input_mask, attention_mask_new_step], dim=1
+ )
+
+ return model_input[:, cond_sequence_end_time_idx:], cond_dict
+
+ def get_input_item(self, batch, k):
+ fname, text, waveform, stft, fbank = (
+ batch["fname"],
+ batch["text"],
+ batch["waveform"],
+ batch["stft"],
+ batch["log_mel_spec"],
+ )
+ ret = {}
+
+ ret["fbank"] = (
+ fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float()
+ )
+ ret["stft"] = stft.to(memory_format=torch.contiguous_format).float()
+ # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float()
+ ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float()
+ ret["text"] = list(text)
+ ret["fname"] = fname
+
+ for key in batch.keys():
+ if key not in ret.keys():
+ ret[key] = batch[key]
+
+ return ret[k]
+
+ def get_input(self, batch):
+ cond_dict = {}
+ if len(self.cond_stage_model_metadata.keys()) > 0:
+ unconditional_cfg = False
+
+ for cond_model_key in self.cond_stage_model_metadata.keys():
+ cond_stage_key = self.cond_stage_model_metadata[cond_model_key][
+ "cond_stage_key"
+ ]
+
+ # if(not self.training):
+ # if(isinstance(self.cond_stage_models[self.cond_stage_model_metadata[cond_model_key]["model_idx"]], CLAPAudioEmbeddingClassifierFreev2)):
+ # assert cond_stage_key == "text" # CLAP model should use text for evaluation
+
+ # The original data for conditioning
+ xc = self.get_input_item(batch, cond_stage_key)
+ if type(xc) == torch.Tensor:
+ xc = xc.to(self.device)
+
+ c = self.get_learned_conditioning(
+ xc, key=cond_model_key, unconditional_cfg=unconditional_cfg
+ )
+ cond_dict[cond_model_key] = c
+
+ return cond_dict
+
+ def instantiate_cond_stage(self, config):
+ self.cond_stage_model_metadata = {}
+
+ for i, cond_model_key in enumerate(config.keys()):
+ model = instantiate_from_config(config[cond_model_key])
+ self.cond_stage_models.append(model)
+ self.cond_stage_model_metadata[cond_model_key] = {
+ "model_idx": i,
+ "cond_stage_key": config[cond_model_key]["cond_stage_key"],
+ "conditioning_key": config[cond_model_key]["conditioning_key"],
+ }
+
+ def get_learned_conditioning(self, c, key, unconditional_cfg):
+ assert key in self.cond_stage_model_metadata.keys()
+
+ # Classifier-free guidance
+ if not unconditional_cfg:
+ c = self.cond_stage_models[
+ self.cond_stage_model_metadata[key]["model_idx"]
+ ](c)
+ else:
+ if isinstance(c, torch.Tensor):
+ batchsize = c.size(0)
+ elif isinstance(c, list):
+ batchsize = len(c)
+ else:
+ raise NotImplementedError()
+ c = self.cond_stage_models[
+ self.cond_stage_model_metadata[key]["model_idx"]
+ ].get_unconditional_condition(batchsize)
+
+ return c
+
+ def initialize_param_check_toolkit(self):
+ self.tracked_steps = 0
+ self.param_dict = {}
+
+ def statistic_require_grad_tensor_number(self, module, name=None):
+ requires_grad_num = 0
+ total_num = 0
+ require_grad_tensor = None
+ for p in module.parameters():
+ if p.requires_grad:
+ requires_grad_num += 1
+ if require_grad_tensor is None:
+ require_grad_tensor = p
+ total_num += 1
+ print(
+ "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)"
+ % (name, requires_grad_num, total_num, requires_grad_num / total_num)
+ )
+ return require_grad_tensor
diff --git a/audioldm2/audiomae_gen/utils.py b/audioldm2/audiomae_gen/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..841d35adf338647bdf8bd1c31e9f33dee1252b6e
--- /dev/null
+++ b/audioldm2/audiomae_gen/utils.py
@@ -0,0 +1,27 @@
+import torch.nn as nn
+
+
+class Prenet(nn.Module):
+ def __init__(self, in_dim, sizes=[256, 128], dropout_rate=0.5):
+ super(Prenet, self).__init__()
+ in_sizes = [in_dim] + sizes[:-1]
+ self.layers = nn.ModuleList(
+ [
+ nn.Linear(in_size, out_size)
+ for (in_size, out_size) in zip(in_sizes, sizes)
+ ]
+ )
+ self.relu = nn.ReLU()
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def forward(self, inputs):
+ for linear in self.layers:
+ inputs = self.dropout(self.relu(linear(inputs)))
+ return inputs
+
+
+if __name__ == "__main__":
+ model = Prenet(in_dim=128, sizes=[256, 256, 128])
+ import ipdb
+
+ ipdb.set_trace()
diff --git a/audioldm2/clap/__init__.py b/audioldm2/clap/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/clap/open_clip/__init__.py b/audioldm2/clap/open_clip/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e9f728f2f273be5d5fdbec6c6cc41d737176a8c0
--- /dev/null
+++ b/audioldm2/clap/open_clip/__init__.py
@@ -0,0 +1,25 @@
+from .factory import (
+ list_models,
+ create_model,
+ create_model_and_transforms,
+ add_model_config,
+)
+from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
+from .model import (
+ CLAP,
+ CLAPTextCfg,
+ CLAPVisionCfg,
+ CLAPAudioCfp,
+ convert_weights_to_fp16,
+ trace_model,
+)
+from .openai import load_openai_model, list_openai_models
+from .pretrained import (
+ list_pretrained,
+ list_pretrained_tag_models,
+ list_pretrained_model_tags,
+ get_pretrained_url,
+ download_pretrained,
+)
+from .tokenizer import SimpleTokenizer, tokenize
+from .transform import image_transform
diff --git a/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz b/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz
new file mode 100755
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/audioldm2/clap/open_clip/factory.py b/audioldm2/clap/open_clip/factory.py
new file mode 100755
index 0000000000000000000000000000000000000000..df0f4a194c2e7328f7b7d3fe11fa6801c6cc1a7c
--- /dev/null
+++ b/audioldm2/clap/open_clip/factory.py
@@ -0,0 +1,276 @@
+import json
+import logging
+import os
+import re
+from copy import deepcopy
+from pathlib import Path
+
+import torch
+
+from .model import CLAP, convert_weights_to_fp16
+from .openai import load_openai_model
+from .pretrained import get_pretrained_url, download_pretrained
+from .transform import image_transform
+
+_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
+_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
+
+
+def _natural_key(string_):
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
+
+
+def _rescan_model_configs():
+ global _MODEL_CONFIGS
+
+ config_ext = (".json",)
+ config_files = []
+ for config_path in _MODEL_CONFIG_PATHS:
+ if config_path.is_file() and config_path.suffix in config_ext:
+ config_files.append(config_path)
+ elif config_path.is_dir():
+ for ext in config_ext:
+ config_files.extend(config_path.glob(f"*{ext}"))
+
+ for cf in config_files:
+ if os.path.basename(cf)[0] == ".":
+ continue # Ignore hidden files
+
+ with open(cf, "r") as f:
+ model_cfg = json.load(f)
+ if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
+ _MODEL_CONFIGS[cf.stem] = model_cfg
+
+ _MODEL_CONFIGS = {
+ k: v
+ for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
+ }
+
+
+_rescan_model_configs() # initial populate of model config registry
+
+
+def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ else:
+ state_dict = checkpoint
+ if skip_params:
+ if next(iter(state_dict.items()))[0].startswith("module"):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+ # for k in state_dict:
+ # if k.startswith('transformer'):
+ # v = state_dict.pop(k)
+ # state_dict['text_branch.' + k[12:]] = v
+ return state_dict
+
+
+def create_model(
+ amodel_name: str,
+ tmodel_name: str,
+ pretrained: str = "",
+ precision: str = "fp32",
+ device: torch.device = torch.device("cpu"),
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
+ skip_params=True,
+ pretrained_audio: str = "",
+ pretrained_text: str = "",
+ enable_fusion: bool = False,
+ fusion_type: str = "None"
+ # pretrained_image: bool = False,
+):
+ amodel_name = amodel_name.replace(
+ "/", "-"
+ ) # for callers using old naming with / in ViT names
+ pretrained_orig = pretrained
+ pretrained = pretrained.lower()
+ if pretrained == "openai":
+ if amodel_name in _MODEL_CONFIGS:
+ logging.info(f"Loading {amodel_name} model config.")
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
+ else:
+ logging.error(
+ f"Model config for {amodel_name} not found; available models {list_models()}."
+ )
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
+
+ logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
+ # Hard Code in model name
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
+ model = load_openai_model(
+ "ViT-B-16",
+ model_cfg,
+ device=device,
+ jit=jit,
+ cache_dir=openai_model_cache_dir,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
+ if precision == "amp" or precision == "fp32":
+ model = model.float()
+ else:
+ if amodel_name in _MODEL_CONFIGS:
+ logging.info(f"Loading {amodel_name} model config.")
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
+ else:
+ logging.error(
+ f"Model config for {amodel_name} not found; available models {list_models()}."
+ )
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
+
+ if force_quick_gelu:
+ # override for use of QuickGELU on non-OpenAI transformer models
+ model_cfg["quick_gelu"] = True
+
+ # if pretrained_image:
+ # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
+ # # pretrained weight loading for timm models set via vision_cfg
+ # model_cfg['vision_cfg']['timm_model_pretrained'] = True
+ # else:
+ # assert False, 'pretrained image towers currently only supported for timm models'
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
+ model_cfg["enable_fusion"] = enable_fusion
+ model_cfg["fusion_type"] = fusion_type
+ model = CLAP(**model_cfg)
+
+ if pretrained:
+ checkpoint_path = ""
+ url = get_pretrained_url(amodel_name, pretrained)
+ if url:
+ checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
+ elif os.path.exists(pretrained_orig):
+ checkpoint_path = pretrained_orig
+ if checkpoint_path:
+ logging.info(
+ f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
+ )
+ ckpt = load_state_dict(checkpoint_path, skip_params=True)
+ model.load_state_dict(ckpt)
+ param_names = [n for n, p in model.named_parameters()]
+ # for n in param_names:
+ # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
+ else:
+ logging.warning(
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
+ )
+ raise RuntimeError(
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
+ )
+
+ if pretrained_audio:
+ if amodel_name.startswith("PANN"):
+ if "Cnn14_mAP" in pretrained_audio: # official checkpoint
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+ audio_ckpt = audio_ckpt["model"]
+ keys = list(audio_ckpt.keys())
+ for key in keys:
+ if (
+ "spectrogram_extractor" not in key
+ and "logmel_extractor" not in key
+ ):
+ v = audio_ckpt.pop(key)
+ audio_ckpt["audio_branch." + key] = v
+ elif os.path.basename(pretrained_audio).startswith(
+ "PANN"
+ ): # checkpoint trained via HTSAT codebase
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+ audio_ckpt = audio_ckpt["state_dict"]
+ keys = list(audio_ckpt.keys())
+ for key in keys:
+ if key.startswith("sed_model"):
+ v = audio_ckpt.pop(key)
+ audio_ckpt["audio_branch." + key[10:]] = v
+ elif os.path.basename(pretrained_audio).startswith(
+ "finetuned"
+ ): # checkpoint trained via linear probe codebase
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+ else:
+ raise ValueError("Unknown audio checkpoint")
+ elif amodel_name.startswith("HTSAT"):
+ if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+ audio_ckpt = audio_ckpt["state_dict"]
+ keys = list(audio_ckpt.keys())
+ for key in keys:
+ if key.startswith("sed_model") and (
+ "spectrogram_extractor" not in key
+ and "logmel_extractor" not in key
+ ):
+ v = audio_ckpt.pop(key)
+ audio_ckpt["audio_branch." + key[10:]] = v
+ elif os.path.basename(pretrained_audio).startswith(
+ "HTSAT"
+ ): # checkpoint trained via HTSAT codebase
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+ audio_ckpt = audio_ckpt["state_dict"]
+ keys = list(audio_ckpt.keys())
+ for key in keys:
+ if key.startswith("sed_model"):
+ v = audio_ckpt.pop(key)
+ audio_ckpt["audio_branch." + key[10:]] = v
+ elif os.path.basename(pretrained_audio).startswith(
+ "finetuned"
+ ): # checkpoint trained via linear probe codebase
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+ else:
+ raise ValueError("Unknown audio checkpoint")
+ else:
+ raise f"this audio encoder pretrained checkpoint is not support"
+
+ model.load_state_dict(audio_ckpt, strict=False)
+ logging.info(
+ f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
+ )
+ param_names = [n for n, p in model.named_parameters()]
+ for n in param_names:
+ print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
+
+ model.to(device=device)
+ if precision == "fp16":
+ assert device.type != "cpu"
+ convert_weights_to_fp16(model)
+
+ if jit:
+ model = torch.jit.script(model)
+
+ return model, model_cfg
+
+
+def create_model_and_transforms(
+ model_name: str,
+ pretrained: str = "",
+ precision: str = "fp32",
+ device: torch.device = torch.device("cpu"),
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ # pretrained_image: bool = False,
+):
+ model = create_model(
+ model_name,
+ pretrained,
+ precision,
+ device,
+ jit,
+ force_quick_gelu=force_quick_gelu,
+ # pretrained_image=pretrained_image
+ )
+ preprocess_train = image_transform(model.visual.image_size, is_train=True)
+ preprocess_val = image_transform(model.visual.image_size, is_train=False)
+ return model, preprocess_train, preprocess_val
+
+
+def list_models():
+ """enumerate available model architectures based on config files"""
+ return list(_MODEL_CONFIGS.keys())
+
+
+def add_model_config(path):
+ """add model config path or file and update registry"""
+ if not isinstance(path, Path):
+ path = Path(path)
+ _MODEL_CONFIG_PATHS.append(path)
+ _rescan_model_configs()
diff --git a/audioldm2/clap/open_clip/feature_fusion.py b/audioldm2/clap/open_clip/feature_fusion.py
new file mode 100755
index 0000000000000000000000000000000000000000..dbe4e170e05894c12ebdc36ba1dc1de65e441b89
--- /dev/null
+++ b/audioldm2/clap/open_clip/feature_fusion.py
@@ -0,0 +1,192 @@
+"""
+Feature Fusion for Varible-Length Data Processing
+AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
+According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
+"""
+
+import torch
+import torch.nn as nn
+
+
+class DAF(nn.Module):
+ """
+ 直接相加 DirectAddFuse
+ """
+
+ def __init__(self):
+ super(DAF, self).__init__()
+
+ def forward(self, x, residual):
+ return x + residual
+
+
+class iAFF(nn.Module):
+ """
+ 多特征融合 iAFF
+ """
+
+ def __init__(self, channels=64, r=4, type="2D"):
+ super(iAFF, self).__init__()
+ inter_channels = int(channels // r)
+
+ if type == "1D":
+ # 本地注意力
+ self.local_att = nn.Sequential(
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+
+ # 全局注意力
+ self.global_att = nn.Sequential(
+ nn.AdaptiveAvgPool1d(1),
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+
+ # 第二次本地注意力
+ self.local_att2 = nn.Sequential(
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+ # 第二次全局注意力
+ self.global_att2 = nn.Sequential(
+ nn.AdaptiveAvgPool1d(1),
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+ elif type == "2D":
+ # 本地注意力
+ self.local_att = nn.Sequential(
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+
+ # 全局注意力
+ self.global_att = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+
+ # 第二次本地注意力
+ self.local_att2 = nn.Sequential(
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+ # 第二次全局注意力
+ self.global_att2 = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+ else:
+ raise f"the type is not supported"
+
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x, residual):
+ flag = False
+ xa = x + residual
+ if xa.size(0) == 1:
+ xa = torch.cat([xa, xa], dim=0)
+ flag = True
+ xl = self.local_att(xa)
+ xg = self.global_att(xa)
+ xlg = xl + xg
+ wei = self.sigmoid(xlg)
+ xi = x * wei + residual * (1 - wei)
+
+ xl2 = self.local_att2(xi)
+ xg2 = self.global_att(xi)
+ xlg2 = xl2 + xg2
+ wei2 = self.sigmoid(xlg2)
+ xo = x * wei2 + residual * (1 - wei2)
+ if flag:
+ xo = xo[0].unsqueeze(0)
+ return xo
+
+
+class AFF(nn.Module):
+ """
+ 多特征融合 AFF
+ """
+
+ def __init__(self, channels=64, r=4, type="2D"):
+ super(AFF, self).__init__()
+ inter_channels = int(channels // r)
+
+ if type == "1D":
+ self.local_att = nn.Sequential(
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+ self.global_att = nn.Sequential(
+ nn.AdaptiveAvgPool1d(1),
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+ elif type == "2D":
+ self.local_att = nn.Sequential(
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+ self.global_att = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+ else:
+ raise f"the type is not supported."
+
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x, residual):
+ flag = False
+ xa = x + residual
+ if xa.size(0) == 1:
+ xa = torch.cat([xa, xa], dim=0)
+ flag = True
+ xl = self.local_att(xa)
+ xg = self.global_att(xa)
+ xlg = xl + xg
+ wei = self.sigmoid(xlg)
+ xo = 2 * x * wei + 2 * residual * (1 - wei)
+ if flag:
+ xo = xo[0].unsqueeze(0)
+ return xo
diff --git a/audioldm2/clap/open_clip/htsat.py b/audioldm2/clap/open_clip/htsat.py
new file mode 100755
index 0000000000000000000000000000000000000000..8bf4fceea2dfef953522c14a3a39a417658f2257
--- /dev/null
+++ b/audioldm2/clap/open_clip/htsat.py
@@ -0,0 +1,1304 @@
+# Ke Chen
+# knutchen@ucsd.edu
+# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
+# Some layers designed on the model
+# below codes are based and referred from https://github.com/microsoft/Swin-Transformer
+# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
+
+import torch
+import torch.nn as nn
+from itertools import repeat
+import collections.abc
+import math
+import warnings
+
+from torch.nn.init import _calculate_fan_in_and_fan_out
+import torch.utils.checkpoint as checkpoint
+
+import random
+
+from torchlibrosa.stft import Spectrogram, LogmelFilterBank
+from torchlibrosa.augmentation import SpecAugmentation
+
+from itertools import repeat
+from .utils import do_mixup, interpolate
+
+from .feature_fusion import iAFF, AFF, DAF
+
+
+# from PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ norm_layer=None,
+ flatten=True,
+ patch_stride=16,
+ enable_fusion=False,
+ fusion_type="None",
+ ):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patch_stride = to_2tuple(patch_stride)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patch_stride = patch_stride
+ self.grid_size = (
+ img_size[0] // patch_stride[0],
+ img_size[1] // patch_stride[1],
+ )
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ padding = (
+ (patch_size[0] - patch_stride[0]) // 2,
+ (patch_size[1] - patch_stride[1]) // 2,
+ )
+
+ if (self.enable_fusion) and (self.fusion_type == "channel_map"):
+ self.proj = nn.Conv2d(
+ in_chans * 4,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=patch_stride,
+ padding=padding,
+ )
+ else:
+ self.proj = nn.Conv2d(
+ in_chans,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=patch_stride,
+ padding=padding,
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ if (self.enable_fusion) and (
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
+ ):
+ self.mel_conv2d = nn.Conv2d(
+ in_chans,
+ embed_dim,
+ kernel_size=(patch_size[0], patch_size[1] * 3),
+ stride=(patch_stride[0], patch_stride[1] * 3),
+ padding=padding,
+ )
+ if self.fusion_type == "daf_2d":
+ self.fusion_model = DAF()
+ elif self.fusion_type == "aff_2d":
+ self.fusion_model = AFF(channels=embed_dim, type="2D")
+ elif self.fusion_type == "iaff_2d":
+ self.fusion_model = iAFF(channels=embed_dim, type="2D")
+
+ def forward(self, x, longer_idx=None):
+ if (self.enable_fusion) and (
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
+ ):
+ global_x = x[:, 0:1, :, :]
+
+ # global processing
+ B, C, H, W = global_x.shape
+ assert (
+ H == self.img_size[0] and W == self.img_size[1]
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ global_x = self.proj(global_x)
+ TW = global_x.size(-1)
+ if len(longer_idx) > 0:
+ # local processing
+ local_x = x[longer_idx, 1:, :, :].contiguous()
+ B, C, H, W = local_x.shape
+ local_x = local_x.view(B * C, 1, H, W)
+ local_x = self.mel_conv2d(local_x)
+ local_x = local_x.view(
+ B, C, local_x.size(1), local_x.size(2), local_x.size(3)
+ )
+ local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
+ TB, TC, TH, _ = local_x.size()
+ if local_x.size(-1) < TW:
+ local_x = torch.cat(
+ [
+ local_x,
+ torch.zeros(
+ (TB, TC, TH, TW - local_x.size(-1)),
+ device=global_x.device,
+ ),
+ ],
+ dim=-1,
+ )
+ else:
+ local_x = local_x[:, :, :, :TW]
+
+ global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x)
+ x = global_x
+ else:
+ B, C, H, W = x.shape
+ assert (
+ H == self.img_size[0] and W == self.img_size[1]
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x)
+
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x
+
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2,
+ )
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+ # type: (Tensor, float, float, float, float) -> Tensor
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ if mode == "fan_in":
+ denom = fan_in
+ elif mode == "fan_out":
+ denom = fan_out
+ elif mode == "fan_avg":
+ denom = (fan_in + fan_out) / 2
+
+ variance = scale / denom
+
+ if distribution == "truncated_normal":
+ # constant is stddev of standard normal truncated to (-2, 2)
+ trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+ elif distribution == "normal":
+ tensor.normal_(std=math.sqrt(variance))
+ elif distribution == "uniform":
+ bound = math.sqrt(3 * variance)
+ tensor.uniform_(-bound, bound)
+ else:
+ raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = (
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ )
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(
+ B, H // window_size, W // window_size, window_size, window_size, -1
+ )
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(
+ self,
+ dim,
+ window_size,
+ num_heads,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1],
+ self.window_size[0] * self.window_size[1],
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
+ 1
+ ).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x, attn
+
+ def extra_repr(self):
+ return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
+
+
+# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
+class SwinTransformerBlock(nn.Module):
+ r"""Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resulotion.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(
+ self,
+ dim,
+ input_resolution,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ norm_before_mlp="ln",
+ ):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ self.norm_before_mlp = norm_before_mlp
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert (
+ 0 <= self.shift_size < self.window_size
+ ), "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim,
+ window_size=to_2tuple(self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ if self.norm_before_mlp == "ln":
+ self.norm2 = nn.LayerNorm(dim)
+ elif self.norm_before_mlp == "bn":
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(
+ 1, 2
+ )
+ else:
+ raise NotImplementedError
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ H, W = self.input_resolution
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ w_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(
+ img_mask, self.window_size
+ ) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(
+ attn_mask != 0, float(-100.0)
+ ).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def forward(self, x):
+ # pdb.set_trace()
+ H, W = self.input_resolution
+ # print("H: ", H)
+ # print("W: ", W)
+ # pdb.set_trace()
+ B, L, C = x.shape
+ # assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
+ )
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(
+ shifted_x, self.window_size
+ ) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(
+ -1, self.window_size * self.window_size, C
+ ) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows, attn = self.attn(
+ x_windows, mask=self.attn_mask
+ ) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
+ )
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x, attn
+
+ def extra_repr(self):
+ return (
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+ )
+
+
+class PatchMerging(nn.Module):
+ r"""Patch Merging Layer.
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self):
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+
+class BasicLayer(nn.Module):
+ """A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(
+ self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ norm_before_mlp="ln",
+ ):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList(
+ [
+ SwinTransformerBlock(
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i]
+ if isinstance(drop_path, list)
+ else drop_path,
+ norm_layer=norm_layer,
+ norm_before_mlp=norm_before_mlp,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(
+ input_resolution, dim=dim, norm_layer=norm_layer
+ )
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ attns = []
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x, attn = blk(x)
+ if not self.training:
+ attns.append(attn.unsqueeze(0))
+ if self.downsample is not None:
+ x = self.downsample(x)
+ if not self.training:
+ attn = torch.cat(attns, dim=0)
+ attn = torch.mean(attn, dim=0)
+ return x, attn
+
+ def extra_repr(self):
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+
+# The Core of HTSAT
+class HTSAT_Swin_Transformer(nn.Module):
+ r"""HTSAT based on the Swin Transformer
+ Args:
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
+ patch_size (int | tuple(int)): Patch size. Default: 4
+ path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
+ in_chans (int): Number of input image channels. Default: 1 (mono)
+ num_classes (int): Number of classes for classification head. Default: 527
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 8
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ config (module): The configuration Module from config.py
+ """
+
+ def __init__(
+ self,
+ spec_size=256,
+ patch_size=4,
+ patch_stride=(4, 4),
+ in_chans=1,
+ num_classes=527,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=8,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False,
+ norm_before_mlp="ln",
+ config=None,
+ enable_fusion=False,
+ fusion_type="None",
+ **kwargs,
+ ):
+ super(HTSAT_Swin_Transformer, self).__init__()
+
+ self.config = config
+ self.spec_size = spec_size
+ self.patch_stride = patch_stride
+ self.patch_size = patch_size
+ self.window_size = window_size
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.ape = ape
+ self.in_chans = in_chans
+ self.num_classes = num_classes
+ self.num_heads = num_heads
+ self.num_layers = len(self.depths)
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
+
+ self.drop_rate = drop_rate
+ self.attn_drop_rate = attn_drop_rate
+ self.drop_path_rate = drop_path_rate
+
+ self.qkv_bias = qkv_bias
+ self.qk_scale = None
+
+ self.patch_norm = patch_norm
+ self.norm_layer = norm_layer if self.patch_norm else None
+ self.norm_before_mlp = norm_before_mlp
+ self.mlp_ratio = mlp_ratio
+
+ self.use_checkpoint = use_checkpoint
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ # process mel-spec ; used only once
+ self.freq_ratio = self.spec_size // self.config.mel_bins
+ window = "hann"
+ center = True
+ pad_mode = "reflect"
+ ref = 1.0
+ amin = 1e-10
+ top_db = None
+ self.interpolate_ratio = 32 # Downsampled ratio
+ # Spectrogram extractor
+ self.spectrogram_extractor = Spectrogram(
+ n_fft=config.window_size,
+ hop_length=config.hop_size,
+ win_length=config.window_size,
+ window=window,
+ center=center,
+ pad_mode=pad_mode,
+ freeze_parameters=True,
+ )
+ # Logmel feature extractor
+ self.logmel_extractor = LogmelFilterBank(
+ sr=config.sample_rate,
+ n_fft=config.window_size,
+ n_mels=config.mel_bins,
+ fmin=config.fmin,
+ fmax=config.fmax,
+ ref=ref,
+ amin=amin,
+ top_db=top_db,
+ freeze_parameters=True,
+ )
+ # Spec augmenter
+ self.spec_augmenter = SpecAugmentation(
+ time_drop_width=64,
+ time_stripes_num=2,
+ freq_drop_width=8,
+ freq_stripes_num=2,
+ ) # 2 2
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
+
+ # split spctrogram into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=self.spec_size,
+ patch_size=self.patch_size,
+ in_chans=self.in_chans,
+ embed_dim=self.embed_dim,
+ norm_layer=self.norm_layer,
+ patch_stride=patch_stride,
+ enable_fusion=self.enable_fusion,
+ fusion_type=self.fusion_type,
+ )
+
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.grid_size
+ self.patches_resolution = patches_resolution
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches, self.embed_dim)
+ )
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
+
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
+
+ # stochastic depth
+ dpr = [
+ x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
+ ] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=int(self.embed_dim * 2**i_layer),
+ input_resolution=(
+ patches_resolution[0] // (2**i_layer),
+ patches_resolution[1] // (2**i_layer),
+ ),
+ depth=self.depths[i_layer],
+ num_heads=self.num_heads[i_layer],
+ window_size=self.window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias,
+ qk_scale=self.qk_scale,
+ drop=self.drop_rate,
+ attn_drop=self.attn_drop_rate,
+ drop_path=dpr[
+ sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1])
+ ],
+ norm_layer=self.norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint,
+ norm_before_mlp=self.norm_before_mlp,
+ )
+ self.layers.append(layer)
+
+ self.norm = self.norm_layer(self.num_features)
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
+
+ SF = (
+ self.spec_size
+ // (2 ** (len(self.depths) - 1))
+ // self.patch_stride[0]
+ // self.freq_ratio
+ )
+ self.tscam_conv = nn.Conv2d(
+ in_channels=self.num_features,
+ out_channels=self.num_classes,
+ kernel_size=(SF, 3),
+ padding=(0, 1),
+ )
+ self.head = nn.Linear(num_classes, num_classes)
+
+ if (self.enable_fusion) and (
+ self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
+ ):
+ self.mel_conv1d = nn.Sequential(
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
+ nn.BatchNorm1d(64),
+ )
+ if self.fusion_type == "daf_1d":
+ self.fusion_model = DAF()
+ elif self.fusion_type == "aff_1d":
+ self.fusion_model = AFF(channels=64, type="1D")
+ elif self.fusion_type == "iaff_1d":
+ self.fusion_model = iAFF(channels=64, type="1D")
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"absolute_pos_embed"}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {"relative_position_bias_table"}
+
+ def forward_features(self, x, longer_idx=None):
+ # A deprecated optimization for using a hierarchical output from different blocks
+
+ frames_num = x.shape[2]
+ x = self.patch_embed(x, longer_idx=longer_idx)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+ for i, layer in enumerate(self.layers):
+ x, attn = layer(x)
+ # for x
+ x = self.norm(x)
+ B, N, C = x.shape
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
+ x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST)
+ B, C, F, T = x.shape
+ # group 2D CNN
+ c_freq_bin = F // self.freq_ratio
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
+ x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)
+ # get latent_output
+ fine_grained_latent_output = torch.mean(x, dim=2)
+ fine_grained_latent_output = interpolate(
+ fine_grained_latent_output.permute(0, 2, 1).contiguous(),
+ 8 * self.patch_stride[1],
+ )
+
+ latent_output = self.avgpool(torch.flatten(x, 2))
+ latent_output = torch.flatten(latent_output, 1)
+
+ # display the attention map, if needed
+
+ x = self.tscam_conv(x)
+ x = torch.flatten(x, 2) # B, C, T
+
+ fpx = interpolate(
+ torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]
+ )
+
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+
+ output_dict = {
+ "framewise_output": fpx, # already sigmoided
+ "clipwise_output": torch.sigmoid(x),
+ "fine_grained_embedding": fine_grained_latent_output,
+ "embedding": latent_output,
+ }
+
+ return output_dict
+
+ def crop_wav(self, x, crop_size, spe_pos=None):
+ time_steps = x.shape[2]
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
+ for i in range(len(x)):
+ if spe_pos is None:
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
+ else:
+ crop_pos = spe_pos
+ tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
+ return tx
+
+ # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model
+ def reshape_wav2img(self, x):
+ B, C, T, F = x.shape
+ target_T = int(self.spec_size * self.freq_ratio)
+ target_F = self.spec_size // self.freq_ratio
+ assert (
+ T <= target_T and F <= target_F
+ ), "the wav size should less than or equal to the swin input size"
+ # to avoid bicubic zero error
+ if T < target_T:
+ x = nn.functional.interpolate(
+ x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
+ )
+ if F < target_F:
+ x = nn.functional.interpolate(
+ x, (x.shape[2], target_F), mode="bicubic", align_corners=True
+ )
+ x = x.permute(0, 1, 3, 2).contiguous()
+ x = x.reshape(
+ x.shape[0],
+ x.shape[1],
+ x.shape[2],
+ self.freq_ratio,
+ x.shape[3] // self.freq_ratio,
+ )
+ # print(x.shape)
+ x = x.permute(0, 1, 3, 2, 4).contiguous()
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
+ return x
+
+ # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model
+ def repeat_wat2img(self, x, cur_pos):
+ B, C, T, F = x.shape
+ target_T = int(self.spec_size * self.freq_ratio)
+ target_F = self.spec_size // self.freq_ratio
+ assert (
+ T <= target_T and F <= target_F
+ ), "the wav size should less than or equal to the swin input size"
+ # to avoid bicubic zero error
+ if T < target_T:
+ x = nn.functional.interpolate(
+ x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
+ )
+ if F < target_F:
+ x = nn.functional.interpolate(
+ x, (x.shape[2], target_F), mode="bicubic", align_corners=True
+ )
+ x = x.permute(0, 1, 3, 2).contiguous() # B C F T
+ x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
+ x = x.repeat(repeats=(1, 1, 4, 1))
+ return x
+
+ def forward(
+ self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None
+ ): # out_feat_keys: List[str] = None):
+ if self.enable_fusion and x["longer"].sum() == 0:
+ # if no audio is longer than 10s, then randomly select one audio to be longer
+ x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
+
+ if not self.enable_fusion:
+ x = x["waveform"].to(device=device, non_blocking=True)
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+ if self.training:
+ x = self.spec_augmenter(x)
+
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+
+ x = self.reshape_wav2img(x)
+ output_dict = self.forward_features(x)
+ else:
+ longer_list = x["longer"].to(device=device, non_blocking=True)
+ x = x["mel_fusion"].to(device=device, non_blocking=True)
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+ longer_list_idx = torch.where(longer_list)[0]
+ if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
+ new_x = x[:, 0:1, :, :].clone().contiguous()
+ if len(longer_list_idx) > 0:
+ # local processing
+ fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
+ FB, FC, FT, FF = fusion_x_local.size()
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
+ fusion_x_local = torch.permute(
+ fusion_x_local, (0, 2, 1)
+ ).contiguous()
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
+ fusion_x_local = fusion_x_local.view(
+ FB, FC, FF, fusion_x_local.size(-1)
+ )
+ fusion_x_local = (
+ torch.permute(fusion_x_local, (0, 2, 1, 3))
+ .contiguous()
+ .flatten(2)
+ )
+ if fusion_x_local.size(-1) < FT:
+ fusion_x_local = torch.cat(
+ [
+ fusion_x_local,
+ torch.zeros(
+ (FB, FF, FT - fusion_x_local.size(-1)),
+ device=device,
+ ),
+ ],
+ dim=-1,
+ )
+ else:
+ fusion_x_local = fusion_x_local[:, :, :FT]
+ # 1D fusion
+ new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
+ new_x[longer_list_idx] = self.fusion_model(
+ new_x[longer_list_idx], fusion_x_local
+ )
+ x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
+ else:
+ x = new_x
+
+ elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
+ x = x # no change
+
+ if self.training:
+ x = self.spec_augmenter(x)
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+
+ x = self.reshape_wav2img(x)
+ output_dict = self.forward_features(x, longer_idx=longer_list_idx)
+
+ # if infer_mode:
+ # # in infer mode. we need to handle different length audio input
+ # frame_num = x.shape[2]
+ # target_T = int(self.spec_size * self.freq_ratio)
+ # repeat_ratio = math.floor(target_T / frame_num)
+ # x = x.repeat(repeats=(1,1,repeat_ratio,1))
+ # x = self.reshape_wav2img(x)
+ # output_dict = self.forward_features(x)
+ # else:
+ # if x.shape[2] > self.freq_ratio * self.spec_size:
+ # if self.training:
+ # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
+ # x = self.reshape_wav2img(x)
+ # output_dict = self.forward_features(x)
+ # else:
+ # # Change: Hard code here
+ # overlap_size = (x.shape[2] - 1) // 4
+ # output_dicts = []
+ # crop_size = (x.shape[2] - 1) // 2
+ # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
+ # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
+ # tx = self.reshape_wav2img(tx)
+ # output_dicts.append(self.forward_features(tx))
+ # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
+ # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
+ # for d in output_dicts:
+ # clipwise_output += d["clipwise_output"]
+ # framewise_output += d["framewise_output"]
+ # clipwise_output = clipwise_output / len(output_dicts)
+ # framewise_output = framewise_output / len(output_dicts)
+ # output_dict = {
+ # 'framewise_output': framewise_output,
+ # 'clipwise_output': clipwise_output
+ # }
+ # else: # this part is typically used, and most easy one
+ # x = self.reshape_wav2img(x)
+ # output_dict = self.forward_features(x)
+ # x = self.head(x)
+
+ # We process the data in the dataloader part, in that here we only consider the input_T < fixed_T
+
+ return output_dict
+
+
+def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"):
+ try:
+ assert audio_cfg.model_name in [
+ "tiny",
+ "base",
+ "large",
+ ], "model name for HTS-AT is wrong!"
+ if audio_cfg.model_name == "tiny":
+ model = HTSAT_Swin_Transformer(
+ spec_size=256,
+ patch_size=4,
+ patch_stride=(4, 4),
+ num_classes=audio_cfg.class_num,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=8,
+ config=audio_cfg,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+ elif audio_cfg.model_name == "base":
+ model = HTSAT_Swin_Transformer(
+ spec_size=256,
+ patch_size=4,
+ patch_stride=(4, 4),
+ num_classes=audio_cfg.class_num,
+ embed_dim=128,
+ depths=[2, 2, 12, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=8,
+ config=audio_cfg,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+ elif audio_cfg.model_name == "large":
+ model = HTSAT_Swin_Transformer(
+ spec_size=256,
+ patch_size=4,
+ patch_stride=(4, 4),
+ num_classes=audio_cfg.class_num,
+ embed_dim=256,
+ depths=[2, 2, 12, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=8,
+ config=audio_cfg,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+
+ return model
+ except:
+ raise RuntimeError(
+ f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
+ )
diff --git a/audioldm2/clap/open_clip/loss.py b/audioldm2/clap/open_clip/loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..37faba58f3693d0659512ab1d6e19614fbda0675
--- /dev/null
+++ b/audioldm2/clap/open_clip/loss.py
@@ -0,0 +1,397 @@
+import torch
+import torch.distributed.nn
+from torch import distributed as dist, nn as nn
+from torch.nn import functional as F
+import numpy as np
+from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+
+def gather_features(
+ audio_features,
+ text_features,
+ audio_features_mlp=None,
+ text_features_mlp=None,
+ local_loss=False,
+ gather_with_grad=False,
+ rank=0,
+ world_size=1,
+ use_horovod=False,
+ mlp_loss=False,
+):
+ if use_horovod:
+ assert hvd is not None, "Please install horovod"
+ if gather_with_grad:
+ all_audio_features = hvd.allgather(audio_features)
+ all_text_features = hvd.allgather(text_features)
+ if mlp_loss:
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
+ else:
+ with torch.no_grad():
+ all_audio_features = hvd.allgather(audio_features)
+ all_text_features = hvd.allgather(text_features)
+ if mlp_loss:
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
+ if not local_loss:
+ # ensure grads for local rank when all_* features don't have a gradient
+ gathered_audio_features = list(
+ all_audio_features.chunk(world_size, dim=0)
+ )
+ gathered_text_features = list(
+ all_text_features.chunk(world_size, dim=0)
+ )
+ gathered_audio_features[rank] = audio_features
+ gathered_text_features[rank] = text_features
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
+ all_text_features = torch.cat(gathered_text_features, dim=0)
+ if mlp_loss:
+ gathered_audio_features_mlp = list(
+ all_audio_features_mlp.chunk(world_size, dim=0)
+ )
+ gathered_text_features_mlp = list(
+ all_text_features_mlp.chunk(world_size, dim=0)
+ )
+ gathered_audio_features_mlp[rank] = audio_features_mlp
+ gathered_text_features_mlp[rank] = text_features_mlp
+ all_audio_features_mlp = torch.cat(
+ gathered_audio_features_mlp, dim=0
+ )
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
+ else:
+ # We gather tensors from all gpus
+ if gather_with_grad:
+ all_audio_features = torch.cat(
+ torch.distributed.nn.all_gather(audio_features), dim=0
+ )
+ all_text_features = torch.cat(
+ torch.distributed.nn.all_gather(text_features), dim=0
+ )
+ if mlp_loss:
+ all_audio_features_mlp = torch.cat(
+ torch.distributed.nn.all_gather(audio_features_mlp), dim=0
+ )
+ all_text_features_mlp = torch.cat(
+ torch.distributed.nn.all_gather(text_features_mlp), dim=0
+ )
+ else:
+ gathered_audio_features = [
+ torch.zeros_like(audio_features) for _ in range(world_size)
+ ]
+ gathered_text_features = [
+ torch.zeros_like(text_features) for _ in range(world_size)
+ ]
+ dist.all_gather(gathered_audio_features, audio_features)
+ dist.all_gather(gathered_text_features, text_features)
+ if mlp_loss:
+ gathered_audio_features_mlp = [
+ torch.zeros_like(audio_features_mlp) for _ in range(world_size)
+ ]
+ gathered_text_features_mlp = [
+ torch.zeros_like(text_features_mlp) for _ in range(world_size)
+ ]
+ dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
+ dist.all_gather(gathered_text_features_mlp, text_features_mlp)
+ if not local_loss:
+ # ensure grads for local rank when all_* features don't have a gradient
+ gathered_audio_features[rank] = audio_features
+ gathered_text_features[rank] = text_features
+ if mlp_loss:
+ gathered_audio_features_mlp[rank] = audio_features_mlp
+ gathered_text_features_mlp[rank] = text_features_mlp
+
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
+ all_text_features = torch.cat(gathered_text_features, dim=0)
+ if mlp_loss:
+ all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
+ if mlp_loss:
+ return (
+ all_audio_features,
+ all_text_features,
+ all_audio_features_mlp,
+ all_text_features_mlp,
+ )
+ else:
+ return all_audio_features, all_text_features
+
+
+class ClipLoss(nn.Module):
+ def __init__(
+ self,
+ local_loss=False,
+ gather_with_grad=False,
+ cache_labels=False,
+ rank=0,
+ world_size=1,
+ use_horovod=False,
+ mlp_loss=False,
+ weight_loss_kappa=0,
+ ):
+ super().__init__()
+ self.local_loss = local_loss
+ self.gather_with_grad = gather_with_grad
+ self.cache_labels = cache_labels
+ self.rank = rank
+ self.world_size = world_size
+ self.use_horovod = use_horovod
+ self.mlp_loss = mlp_loss
+ self.weighted_loss = bool(weight_loss_kappa != 0)
+ self.weight_loss_kappa = weight_loss_kappa
+ # cache state
+ self.prev_num_logits = 0
+ self.labels = {}
+
+ def forward(
+ self,
+ audio_features,
+ text_features,
+ logit_scale_a,
+ logit_scale_t=None,
+ audio_features_mlp=None,
+ text_features_mlp=None,
+ ):
+ device = audio_features.device
+ if self.mlp_loss:
+ if self.world_size > 1:
+ (
+ all_audio_features,
+ all_text_features,
+ all_audio_features_mlp,
+ all_text_features_mlp,
+ ) = gather_features(
+ audio_features=audio_features,
+ text_features=text_features,
+ audio_features_mlp=audio_features_mlp,
+ text_features_mlp=text_features_mlp,
+ local_loss=self.local_loss,
+ gather_with_grad=self.gather_with_grad,
+ rank=self.rank,
+ world_size=self.world_size,
+ use_horovod=self.use_horovod,
+ mlp_loss=self.mlp_loss,
+ )
+ if self.local_loss:
+ a_logits_per_audio = (
+ logit_scale_a * audio_features @ all_text_features_mlp.T
+ )
+ a_logits_per_text = (
+ logit_scale_a * text_features_mlp @ all_audio_features.T
+ )
+ t_logits_per_audio = (
+ logit_scale_t * audio_features_mlp @ all_text_features.T
+ )
+ t_logits_per_text = (
+ logit_scale_t * text_features @ all_audio_features_mlp.T
+ )
+ else:
+ a_logits_per_audio = (
+ logit_scale_a * all_audio_features @ all_text_features_mlp.T
+ )
+ a_logits_per_text = a_logits_per_audio.T
+ t_logits_per_audio = (
+ logit_scale_t * all_audio_features_mlp @ all_text_features.T
+ )
+ t_logits_per_text = t_logits_per_audio.T
+ else:
+ a_logits_per_audio = (
+ logit_scale_a * audio_features @ text_features_mlp.T
+ )
+ a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
+ t_logits_per_audio = (
+ logit_scale_t * audio_features_mlp @ text_features.T
+ )
+ t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
+
+ # calculated ground-truth and cache if enabled
+ num_logits = a_logits_per_audio.shape[0]
+ if self.prev_num_logits != num_logits or device not in self.labels:
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
+ if self.world_size > 1 and self.local_loss:
+ labels = labels + num_logits * self.rank
+ if self.cache_labels:
+ self.labels[device] = labels
+ self.prev_num_logits = num_logits
+ else:
+ labels = self.labels[device]
+
+ if not self.weighted_loss:
+ total_loss = (
+ F.cross_entropy(a_logits_per_audio, labels)
+ + F.cross_entropy(a_logits_per_text, labels)
+ + F.cross_entropy(t_logits_per_audio, labels)
+ + F.cross_entropy(t_logits_per_text, labels)
+ ) / 4
+ else:
+ audio_weight = (audio_features @ audio_features.T).detach()
+ audio_weight = (
+ torch.exp(
+ torch.sum(audio_weight, axis=1)
+ / (self.weight_loss_kappa * len(audio_weight))
+ )
+ ).detach()
+ text_weight = (text_features @ text_features.T).detach()
+ text_weight = (
+ torch.exp(
+ torch.sum(text_weight, axis=1)
+ / (self.weight_loss_kappa * len(text_features))
+ )
+ ).detach()
+ total_loss = (
+ F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
+ + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
+ + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
+ + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
+ ) / 4
+ else:
+ if self.world_size > 1:
+ all_audio_features, all_text_features = gather_features(
+ audio_features=audio_features,
+ text_features=text_features,
+ local_loss=self.local_loss,
+ gather_with_grad=self.gather_with_grad,
+ rank=self.rank,
+ world_size=self.world_size,
+ use_horovod=self.use_horovod,
+ mlp_loss=self.mlp_loss,
+ )
+
+ if self.local_loss:
+ logits_per_audio = (
+ logit_scale_a * audio_features @ all_text_features.T
+ )
+ logits_per_text = (
+ logit_scale_a * text_features @ all_audio_features.T
+ )
+ else:
+ logits_per_audio = (
+ logit_scale_a * all_audio_features @ all_text_features.T
+ )
+ logits_per_text = logits_per_audio.T
+ else:
+ logits_per_audio = logit_scale_a * audio_features @ text_features.T
+ logits_per_text = logit_scale_a * text_features @ audio_features.T
+
+ # calculated ground-truth and cache if enabled
+ num_logits = logits_per_audio.shape[0]
+ if self.prev_num_logits != num_logits or device not in self.labels:
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
+ if self.world_size > 1 and self.local_loss:
+ labels = labels + num_logits * self.rank
+ if self.cache_labels:
+ self.labels[device] = labels
+ self.prev_num_logits = num_logits
+ else:
+ labels = self.labels[device]
+ if not self.weighted_loss:
+ total_loss = (
+ F.cross_entropy(logits_per_audio, labels)
+ + F.cross_entropy(logits_per_text, labels)
+ ) / 2
+ else:
+ audio_weight = (all_audio_features @ all_audio_features.T).detach()
+ audio_weight = (
+ torch.exp(
+ torch.sum(audio_weight, axis=1)
+ / (self.weight_loss_kappa * len(all_audio_features))
+ )
+ ).detach()
+ text_weight = (all_text_features @ all_text_features.T).detach()
+ text_weight = (
+ torch.exp(
+ torch.sum(text_weight, axis=1)
+ / (self.weight_loss_kappa * len(all_text_features))
+ )
+ ).detach()
+ total_loss = (
+ F.cross_entropy(logits_per_audio, labels, weight=text_weight)
+ + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
+ ) / 2
+ return total_loss
+
+
+def lp_gather_features(pred, target, world_size=1, use_horovod=False):
+ if use_horovod:
+ assert hvd is not None, "Please install horovod"
+ with torch.no_grad():
+ all_preds = hvd.allgather(pred)
+ all_targets = hvd.allgath(target)
+ else:
+ gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
+ gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
+
+ dist.all_gather(gathered_preds, pred)
+ dist.all_gather(gathered_targets, target)
+ all_preds = torch.cat(gathered_preds, dim=0)
+ all_targets = torch.cat(gathered_targets, dim=0)
+
+ return all_preds, all_targets
+
+
+def get_map(pred, target):
+ pred = torch.sigmoid(pred).numpy()
+ target = target.numpy()
+ return np.mean(average_precision_score(target, pred, average=None))
+
+
+def get_acc(pred, target):
+ pred = torch.argmax(pred, 1).numpy()
+ target = torch.argmax(target, 1).numpy()
+ return accuracy_score(target, pred)
+
+
+def get_mauc(pred, target):
+ pred = torch.sigmoid(pred).numpy()
+ target = target.numpy()
+ return np.mean(roc_auc_score(target, pred, average=None))
+
+
+class LPMetrics(object):
+ def __init__(self, metric_names=["map", "acc", "mauc"]):
+ self.metrics = []
+ for name in metric_names:
+ self.metrics.append(self.get_metric(name))
+ self.metric_names = metric_names
+
+ def get_metric(self, name):
+ if name == "map":
+ return get_map
+ elif name == "acc":
+ return get_acc
+ elif name == "mauc":
+ return get_mauc
+ else:
+ raise ValueError(f"the metric should be at least one of [map, acc, mauc]")
+
+ def evaluate_mertics(self, pred, target):
+ metric_dict = {}
+ for i in range(len(self.metric_names)):
+ metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
+ return metric_dict
+
+
+def calc_celoss(pred, target):
+ target = torch.argmax(target, 1).long()
+ return nn.CrossEntropyLoss()(pred, target)
+
+
+class LPLoss(nn.Module):
+ def __init__(self, loss_name):
+ super().__init__()
+ if loss_name == "bce":
+ self.loss_func = nn.BCEWithLogitsLoss()
+ elif loss_name == "ce":
+ self.loss_func = calc_celoss
+ elif loss_name == "mse":
+ self.loss_func = nn.MSELoss()
+ else:
+ raise ValueError(f"the loss func should be at least one of [bce, ce, mse]")
+
+ def forward(self, pred, target):
+ loss = self.loss_func(pred, target)
+ return loss
diff --git a/audioldm2/clap/open_clip/model.py b/audioldm2/clap/open_clip/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..130fb582d016868d478e2d10e90d7fc0e7999078
--- /dev/null
+++ b/audioldm2/clap/open_clip/model.py
@@ -0,0 +1,931 @@
+""" CLAP Model
+
+Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+Adapted to the Audio Task.
+"""
+
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Tuple, Union, Callable, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+import logging
+from .utils import freeze_batch_norm_2d
+
+from .pann_model import create_pann_model
+from .htsat import create_htsat_model
+from transformers import BertModel, RobertaModel, BartModel, RobertaConfig
+
+
+class MLPLayers(nn.Module):
+ def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
+ super(MLPLayers, self).__init__()
+ self.nonlin = nonlin
+ self.dropout = dropout
+
+ sequence = []
+ for u0, u1 in zip(units[:-1], units[1:]):
+ sequence.append(nn.Linear(u0, u1))
+ sequence.append(self.nonlin)
+ sequence.append(nn.Dropout(self.dropout))
+ sequence = sequence[:-2]
+
+ self.sequential = nn.Sequential(*sequence)
+
+ def forward(self, X):
+ X = self.sequential(X)
+ return X
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(
+ OrderedDict(
+ [
+ ("-1", nn.AvgPool2d(stride)),
+ (
+ "0",
+ nn.Conv2d(
+ inplanes,
+ planes * self.expansion,
+ 1,
+ stride=1,
+ bias=False,
+ ),
+ ),
+ ("1", nn.BatchNorm2d(planes * self.expansion)),
+ ]
+ )
+ )
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu(self.bn1(self.conv1(x)))
+ out = self.relu(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(
+ self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
+ )
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
+ 2, 0, 1
+ ) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x,
+ key=x,
+ value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat(
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
+ ),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False,
+ )
+
+ return x[0]
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.image_size = image_size
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
+ )
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.conv2 = nn.Conv2d(
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False
+ )
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.avgpool = nn.AvgPool2d(2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
+
+ self.init_parameters()
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def init_parameters(self):
+ if self.attnpool is not None:
+ std = self.attnpool.c_proj.in_features**-0.5
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert (
+ unlocked_groups == 0
+ ), "partial locking not currently supported for this model"
+ for param in self.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self)
+
+ def stem(self, x):
+ for conv, bn in [
+ (self.conv1, self.bn1),
+ (self.conv2, self.bn2),
+ (self.conv3, self.bn3),
+ ]:
+ x = self.relu(bn(conv(x)))
+ x = self.avgpool(x)
+ return x
+
+ def forward(self, x):
+ x = self.stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ return x.to(orig_type)
+
+
+class QuickGELU(nn.Module):
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(
+ OrderedDict(
+ [
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
+ ]
+ )
+ )
+ self.ln_2 = LayerNorm(d_model)
+
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
+ ):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.ModuleList(
+ [
+ ResidualAttentionBlock(width, heads, act_layer=act_layer)
+ for _ in range(layers)
+ ]
+ )
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ for r in self.resblocks:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+
+class VisualTransformer(nn.Module):
+ def __init__(
+ self,
+ image_size: int,
+ patch_size: int,
+ width: int,
+ layers: int,
+ heads: int,
+ output_dim: int,
+ act_layer: Callable = nn.GELU,
+ ):
+ super().__init__()
+ self.image_size = image_size
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(
+ in_channels=3,
+ out_channels=width,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=False,
+ )
+
+ scale = width**-0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(
+ scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
+ )
+ self.ln_pre = LayerNorm(width)
+
+ self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert (
+ unlocked_groups == 0
+ ), "partial locking not currently supported for this model"
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, x: torch.Tensor):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat(
+ [
+ self.class_embedding.to(x.dtype)
+ + torch.zeros(
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
+ ),
+ x,
+ ],
+ dim=1,
+ ) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_branch(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+
+@dataclass
+class CLAPVisionCfg:
+ layers: Union[Tuple[int, int, int, int], int] = 12
+ width: int = 768
+ patch_size: int = 16
+ image_size: Union[Tuple[int, int], int] = 224
+ timm_model_name: str = (
+ None # a valid model name overrides layers, width, patch_size
+ )
+ timm_model_pretrained: bool = (
+ False # use (imagenet) pretrained weights for named model
+ )
+ timm_pool: str = (
+ "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
+ )
+ timm_proj: str = (
+ "linear" # linear projection for timm model output ('linear', 'mlp', '')
+ )
+
+
+# Audio Config Class
+@dataclass
+class CLAPAudioCfp:
+ model_type: str = "PANN"
+ model_name: str = "Cnn14"
+ sample_rate: int = 48000
+ # Param
+ audio_length: int = 1024
+ window_size: int = 1024
+ hop_size: int = 1024
+ fmin: int = 50
+ fmax: int = 14000
+ class_num: int = 527
+ mel_bins: int = 64
+ clip_samples: int = 480000
+
+
+@dataclass
+class CLAPTextCfg:
+ context_length: int
+ vocab_size: int
+ width: int
+ heads: int
+ layers: int
+ model_type: str
+
+
+class CLAP(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ audio_cfg: CLAPAudioCfp,
+ text_cfg: CLAPTextCfg,
+ quick_gelu: bool = False,
+ enable_fusion: bool = False,
+ fusion_type: str = "None",
+ joint_embed_shape: int = 512,
+ mlp_act: str = "relu",
+ ):
+ super().__init__()
+ if isinstance(audio_cfg, dict):
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
+ if isinstance(text_cfg, dict):
+ text_cfg = CLAPTextCfg(**text_cfg)
+
+ self.audio_cfg = audio_cfg
+ self.text_cfg = text_cfg
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+ self.joint_embed_shape = joint_embed_shape
+ self.mlp_act = mlp_act
+
+ self.context_length = text_cfg.context_length
+
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
+ # memory efficient in recent PyTorch releases (>= 1.10).
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
+ act_layer = QuickGELU if quick_gelu else nn.GELU
+
+ if mlp_act == "relu":
+ mlp_act_layer = nn.ReLU()
+ elif mlp_act == "gelu":
+ mlp_act_layer = nn.GELU()
+ else:
+ raise NotImplementedError
+
+ # audio branch
+ # audio branch parameters
+ if audio_cfg.model_type == "PANN":
+ self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
+ elif audio_cfg.model_type == "HTSAT":
+ self.audio_branch = create_htsat_model(
+ audio_cfg, enable_fusion, fusion_type
+ )
+ else:
+ logging.error(f"Model config for {audio_cfg.model_type} not found")
+ raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
+
+ # text branch
+ # text branch parameters
+ if text_cfg.model_type == "transformer":
+ self.text_branch = Transformer(
+ width=text_cfg.width,
+ layers=text_cfg.layers,
+ heads=text_cfg.heads,
+ act_layer=act_layer,
+ )
+ self.vocab_size = text_cfg.vocab_size
+ self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
+ self.positional_embedding = nn.Parameter(
+ torch.empty(self.context_length, text_cfg.width)
+ )
+ self.ln_final = LayerNorm(text_cfg.width)
+ self.text_transform = MLPLayers(
+ units=[
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ ],
+ dropout=0.1,
+ )
+ self.text_projection = nn.Sequential(
+ nn.Linear(text_cfg.width, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+ )
+ elif text_cfg.model_type == "bert":
+ self.text_branch = BertModel.from_pretrained("bert-base-uncased")
+ self.text_transform = MLPLayers(
+ units=[
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ ],
+ dropout=0.1,
+ )
+ self.text_projection = nn.Sequential(
+ nn.Linear(768, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+ )
+ elif text_cfg.model_type == "roberta":
+ self.text_branch = RobertaModel(
+ RobertaConfig.from_pretrained("roberta-base")
+ )
+ self.text_transform = MLPLayers(
+ units=[
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ ],
+ dropout=0.1,
+ )
+ self.text_projection = nn.Sequential(
+ nn.Linear(768, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+ )
+ elif text_cfg.model_type == "bart":
+ self.text_branch = BartModel.from_pretrained("facebook/bart-base")
+ self.text_transform = MLPLayers(
+ units=[
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ ],
+ dropout=0.1,
+ )
+ self.text_projection = nn.Sequential(
+ nn.Linear(768, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+ )
+ else:
+ logging.error(f"Model config for {text_cfg.model_type} not found")
+ raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
+ self.text_branch_type = text_cfg.model_type
+ # text branch parameters
+
+ # audio branch parameters
+ self.audio_transform = MLPLayers(
+ units=[
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ ],
+ dropout=0.1,
+ )
+
+ # below here is text branch parameters
+
+ # ============================================================================================================
+ self.audio_projection = nn.Sequential(
+ nn.Linear(embed_dim, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+ )
+
+ self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
+
+ self.init_text_branch_parameters()
+
+ def init_text_branch_parameters(self):
+ if self.text_branch_type == "transformer":
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+ proj_std = (self.text_branch.width**-0.5) * (
+ (2 * self.text_branch.layers) ** -0.5
+ )
+ attn_std = self.text_branch.width**-0.5
+ fc_std = (2 * self.text_branch.width) ** -0.5
+ for block in self.text_branch.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+ if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
+ self.text_branch.embeddings.word_embeddings.weight.shape[-1]
+ elif self.text_branch_type == "bart":
+ self.text_branch.shared.weight.shape[-1]
+ else:
+ self.text_branch.width
+ nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
+ nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
+
+ # deprecated
+ # if hasattr(self.visual, 'init_parameters'):
+ # self.visual.init_parameters()
+
+ # if self.text_projection is not None:
+ # nn.init.normal_(self.text_projection, std=width**-0.5)
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ def encode_audio(self, audio, device):
+ return self.audio_branch(
+ audio, mixup_lambda=None, device=device
+ ) # mix lambda needs to add
+
+ # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
+ # tmp = {}
+ # for k in x[0].keys():
+ # tmp[k] = []
+ # for i in range(len(x)):
+ # tmp[k].append(x[i][k][:77])
+ # for k in x[0].keys():
+ # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
+ # return tmp
+
+ def encode_text(self, text, device):
+ if self.text_branch_type == "transformer":
+ text = text.to(device=device, non_blocking=True)
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_branch(x, attn_mask=self.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
+ elif self.text_branch_type == "bert":
+ # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
+ # text = BatchEncoding(text)
+ x = self.text_branch(
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
+ attention_mask=text["attention_mask"].to(
+ device=device, non_blocking=True
+ ),
+ token_type_ids=text["token_type_ids"].to(
+ device=device, non_blocking=True
+ ),
+ )["pooler_output"]
+ x = self.text_projection(x)
+ elif self.text_branch_type == "roberta":
+ x = self.text_branch(
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
+ attention_mask=text["attention_mask"].to(
+ device=device, non_blocking=True
+ ),
+ )["pooler_output"]
+ x = self.text_projection(x)
+ elif self.text_branch_type == "bart":
+ x = torch.mean(
+ self.text_branch(
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
+ attention_mask=text["attention_mask"].to(
+ device=device, non_blocking=True
+ ),
+ )["encoder_last_hidden_state"],
+ axis=1,
+ )
+ x = self.text_projection(x)
+ else:
+ logging.error(f"Model type {self.text_branch_type} not found")
+ raise RuntimeError(f"Model type {self.text_branch_type} not found.")
+ return x
+
+ def forward(self, audio, text, device=None):
+ """Forward audio and text into the CLAP
+
+ Parameters
+ ----------
+ audio: torch.Tensor (batch_size, audio_length)
+ the time-domain audio input / the batch of mel_spec and longer list.
+ text: torch.Tensor () // need to add
+ the text token input
+ """
+ if device is None:
+ if audio is not None:
+ device = audio.device
+ elif text is not None:
+ device = text.device
+ if audio is None and text is None:
+ # a hack to get the logit scale
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
+ elif audio is None:
+ return self.encode_text(text, device=device)
+ elif text is None:
+ return self.audio_projection(
+ self.encode_audio(audio, device=device)["embedding"]
+ )
+ audio_features = self.audio_projection(
+ self.encode_audio(audio, device=device)["embedding"]
+ )
+ audio_features = F.normalize(audio_features, dim=-1)
+
+ text_features = self.encode_text(text, device=device)
+ # print("text_features", text_features)
+ # print("text_features.shape", text_features.shape)
+ # print("text_features.type", type(text_features))
+ text_features = F.normalize(text_features, dim=-1)
+
+ audio_features_mlp = self.audio_transform(audio_features)
+ text_features_mlp = self.text_transform(text_features)
+ # Four outputs: audio features (basic & MLP), text features (basic & MLP)
+ return (
+ audio_features,
+ text_features,
+ audio_features_mlp,
+ text_features_mlp,
+ self.logit_scale_a.exp(),
+ self.logit_scale_t.exp(),
+ )
+
+ def get_logit_scale(self):
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
+
+ def get_text_embedding(self, data):
+ """Get the text embedding from the model
+
+ Parameters
+ ----------
+ data: torch.Tensor
+ a tensor of text embedding
+
+ Returns
+ ----------
+ text_embed: torch.Tensor
+ a tensor of text_embeds (N, D)
+
+ """
+ device = next(self.parameters()).device
+ for k in data:
+ data[k] = data[k].to(device)
+ text_embeds = self.encode_text(data, device=device)
+ text_embeds = F.normalize(text_embeds, dim=-1)
+
+ return text_embeds
+
+ def get_audio_embedding(self, data):
+ """Get the audio embedding from the model
+
+ Parameters
+ ----------
+ data: a list of dict
+ the audio input dict list from 'get_audio_feature' method
+
+ Returns
+ ----------
+ audio_embed: torch.Tensor
+ a tensor of audio_embeds (N, D)
+
+ """
+ device = next(self.parameters()).device
+ # input_dict = {}
+ # keys = data[0].keys()
+ # for k in keys:
+ # input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(
+ # device
+ # )
+ audio_embeds = self.audio_projection(
+ self.encode_audio(data, device=device)["embedding"]
+ )
+ audio_embeds = F.normalize(audio_embeds, dim=-1)
+
+ return audio_embeds
+
+ def audio_infer(self, audio, hopsize=None, device=None):
+ """Forward one audio and produce the audio embedding
+
+ Parameters
+ ----------
+ audio: (audio_length)
+ the time-domain audio input, notice that it must be only one input
+ hopsize: int
+ the overlap hopsize as the sliding window
+
+ Returns
+ ----------
+ output_dict: {
+ key: [n, (embedding_shape)] if "HTS-AT"
+ or
+ key: [(embedding_shape)] if "PANN"
+ }
+ the list of key values of the audio branch
+
+ """
+
+ assert not self.training, "the inference mode must be run at eval stage"
+ output_dict = {}
+ # PANN
+ if self.audio_cfg.model_type == "PANN":
+ audio_input = audio.unsqueeze(dim=0)
+ output_dict[key] = self.encode_audio(audio_input, device=device)[
+ key
+ ].squeeze(dim=0)
+ elif self.audio_cfg.model_type == "HTSAT":
+ # repeat
+ audio_len = len(audio)
+ k = self.audio_cfg.clip_samples // audio_len
+ if k > 1:
+ audio = audio.repeat(k)
+ audio_len = len(audio)
+
+ if hopsize is None:
+ hopsize = min(hopsize, audio_len)
+
+ if audio_len > self.audio_cfg.clip_samples:
+ audio_input = [
+ audio[pos : pos + self.audio_cfg.clip_samples].clone()
+ for pos in range(
+ 0, audio_len - self.audio_cfg.clip_samples, hopsize
+ )
+ ]
+ audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
+ audio_input = torch.stack(audio_input)
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key]
+ else:
+ audio_input = audio.unsqueeze(dim=0)
+ output_dict[key] = self.encode_audio(audio_input, device=device)[
+ key
+ ].squeeze(dim=0)
+
+ return output_dict
+
+
+def convert_weights_to_fp16(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+ if isinstance(l, nn.MultiheadAttention):
+ for attr in [
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
+ "in_proj_bias",
+ "bias_k",
+ "bias_v",
+ ]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+# Ignore the state dict of the vision part
+def build_model_from_openai_state_dict(
+ state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
+):
+ embed_dim = model_cfg["embed_dim"]
+ audio_cfg = model_cfg["audio_cfg"]
+ text_cfg = model_cfg["text_cfg"]
+ state_dict["positional_embedding"].shape[0]
+ state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_width // 64
+ transformer_layers = len(
+ set(
+ k.split(".")[2]
+ for k in state_dict
+ if k.startswith(f"transformer.resblocks")
+ )
+ )
+
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
+ text_cfg = CLAPTextCfg(**text_cfg)
+
+ model = CLAP(
+ embed_dim,
+ audio_cfg=audio_cfg,
+ text_cfg=text_cfg,
+ quick_gelu=True, # OpenAI models were trained with QuickGELU
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+ state_dict["logit_scale_a"] = state_dict["logit_scale"]
+ state_dict["logit_scale_t"] = state_dict["logit_scale"]
+ pop_keys = list(state_dict.keys())[::]
+ # pop the visual branch saved weights
+ for key in pop_keys:
+ if key.startswith("visual."):
+ state_dict.pop(key, None)
+
+ for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
+ state_dict.pop(key, None)
+
+ # not use fp16
+ # convert_weights_to_fp16(model)
+ model.load_state_dict(state_dict, strict=False)
+ return model.eval()
+
+
+def trace_model(model, batch_size=256, device=torch.device("cpu")):
+ model.eval()
+ audio_length = model.audio_cfg.audio_length
+ example_audio = torch.ones((batch_size, audio_length), device=device)
+ example_text = torch.zeros(
+ (batch_size, model.context_length), dtype=torch.int, device=device
+ )
+ model = torch.jit.trace_module(
+ model,
+ inputs=dict(
+ forward=(example_audio, example_text),
+ encode_text=(example_text,),
+ encode_image=(example_audio,),
+ ),
+ )
+ model.audio_cfg.audio_length = audio_length # Question: what does this do?
+ return model
diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-base.json b/audioldm2/clap/open_clip/model_configs/HTSAT-base.json
new file mode 100755
index 0000000000000000000000000000000000000000..6cef625a89daf4431f1c9f72e10bc9640eef2ba8
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/HTSAT-base.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 1024,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "HTSAT",
+ "model_name": "base"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-large.json b/audioldm2/clap/open_clip/model_configs/HTSAT-large.json
new file mode 100755
index 0000000000000000000000000000000000000000..699cdb1b16855582606551e4196b24aba2ffd871
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/HTSAT-large.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "HTSAT",
+ "model_name": "large"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json
new file mode 100755
index 0000000000000000000000000000000000000000..73e42990fe8361a0df502e7f93d29f19f58c9ecb
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 768,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1536,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "HTSAT",
+ "model_name": "tiny"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json
new file mode 100755
index 0000000000000000000000000000000000000000..a6e7821163d9afa81c27345a1e472475b92af169
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 768,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "HTSAT",
+ "model_name": "tiny"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-10.json b/audioldm2/clap/open_clip/model_configs/PANN-10.json
new file mode 100755
index 0000000000000000000000000000000000000000..954ddf62921aed7dde9c37ffffec98a2e96a4ee7
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-10.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 1024,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn10"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json
new file mode 100755
index 0000000000000000000000000000000000000000..b7989bc0cd95d0d39049b7524eba508b3e386439
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 18000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json
new file mode 100755
index 0000000000000000000000000000000000000000..56bdb56bedc304ffa52d8bf5988cea2c1d82d14e
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 960000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 360,
+ "fmin": 50,
+ "fmax": 8000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json b/audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json
new file mode 100755
index 0000000000000000000000000000000000000000..5756e3bebc97cc985f512cb081930fee4e49bec1
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 4
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json b/audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json
new file mode 100755
index 0000000000000000000000000000000000000000..5a9e7e208b661619d5e26625e849da1adda8a475
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1536,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14.json b/audioldm2/clap/open_clip/model_configs/PANN-14.json
new file mode 100755
index 0000000000000000000000000000000000000000..39a5134cde1d8c50f4758377c952ef22f07bab41
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-14.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-6.json b/audioldm2/clap/open_clip/model_configs/PANN-6.json
new file mode 100755
index 0000000000000000000000000000000000000000..21ebc344326de260c386ba77e0ad63cf9b04febf
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-6.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 512,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn6"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json b/audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json
new file mode 100755
index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json
@@ -0,0 +1,22 @@
+{
+ "embed_dim": 512,
+ "quick_gelu": true,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 23,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/RN101.json b/audioldm2/clap/open_clip/model_configs/RN101.json
new file mode 100755
index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/RN101.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 23,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json b/audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json
new file mode 100755
index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json
@@ -0,0 +1,22 @@
+{
+ "embed_dim": 1024,
+ "quick_gelu": true,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 6,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
diff --git a/audioldm2/clap/open_clip/model_configs/RN50.json b/audioldm2/clap/open_clip/model_configs/RN50.json
new file mode 100755
index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/RN50.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 6,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/RN50x16.json b/audioldm2/clap/open_clip/model_configs/RN50x16.json
new file mode 100755
index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/RN50x16.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "image_size": 384,
+ "layers": [
+ 6,
+ 8,
+ 18,
+ 8
+ ],
+ "width": 96,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/RN50x4.json b/audioldm2/clap/open_clip/model_configs/RN50x4.json
new file mode 100755
index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/RN50x4.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 640,
+ "vision_cfg": {
+ "image_size": 288,
+ "layers": [
+ 4,
+ 6,
+ 10,
+ 6
+ ],
+ "width": 80,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 640,
+ "heads": 10,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/ViT-B-16.json b/audioldm2/clap/open_clip/model_configs/ViT-B-16.json
new file mode 100755
index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/ViT-B-16.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 16
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json b/audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json
new file mode 100755
index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 512,
+ "quick_gelu": true,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/ViT-B-32.json b/audioldm2/clap/open_clip/model_configs/ViT-B-32.json
new file mode 100755
index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/ViT-B-32.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/ViT-L-14.json b/audioldm2/clap/open_clip/model_configs/ViT-L-14.json
new file mode 100755
index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/ViT-L-14.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 24,
+ "width": 1024,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/openai.py b/audioldm2/clap/open_clip/openai.py
new file mode 100755
index 0000000000000000000000000000000000000000..3f4eb8b55fe960e1792b3da804b60b3d8f70fe26
--- /dev/null
+++ b/audioldm2/clap/open_clip/openai.py
@@ -0,0 +1,156 @@
+""" OpenAI pretrained model functions
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+
+import os
+import warnings
+from typing import Union, List
+
+import torch
+
+from .model import build_model_from_openai_state_dict
+from .pretrained import (
+ get_pretrained_url,
+ list_pretrained_tag_models,
+ download_pretrained,
+)
+
+__all__ = ["list_openai_models", "load_openai_model"]
+
+
+def list_openai_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list_pretrained_tag_models("openai")
+
+
+def load_openai_model(
+ name: str,
+ model_cfg,
+ device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
+ jit=True,
+ cache_dir=os.path.expanduser("~/.cache/clip"),
+ enable_fusion: bool = False,
+ fusion_type: str = "None",
+):
+ """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model
+
+ Parameters
+ ----------
+ name : str
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+ device : Union[str, torch.device]
+ The device to put the loaded model
+ jit : bool
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+
+ Returns
+ -------
+ model : torch.nn.Module
+ The CLAP model
+ preprocess : Callable[[PIL.Image], torch.Tensor]
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+ """
+ if get_pretrained_url(name, "openai"):
+ model_path = download_pretrained(
+ get_pretrained_url(name, "openai"), root=cache_dir
+ )
+ elif os.path.isfile(name):
+ model_path = name
+ else:
+ raise RuntimeError(
+ f"Model {name} not found; available models = {list_openai_models()}"
+ )
+
+ try:
+ # loading JIT archive
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ state_dict = None
+ except RuntimeError:
+ # loading saved state dict
+ if jit:
+ warnings.warn(
+ f"File {model_path} is not a JIT archive. Loading as a state dict instead"
+ )
+ jit = False
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ if not jit:
+ try:
+ model = build_model_from_openai_state_dict(
+ state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
+ ).to(device)
+ except KeyError:
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
+ model = build_model_from_openai_state_dict(
+ sd, model_cfg, enable_fusion, fusion_type
+ ).to(device)
+
+ if str(device) == "cpu":
+ model.float()
+ return model
+
+ # patch the device names
+ device_holder = torch.jit.trace(
+ lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
+ )
+ device_node = [
+ n
+ for n in device_holder.graph.findAllNodes("prim::Constant")
+ if "Device" in repr(n)
+ ][-1]
+
+ def patch_device(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith(
+ "cuda"
+ ):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_audio)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if str(device) == "cpu":
+ float_holder = torch.jit.trace(
+ lambda: torch.ones([]).float(), example_inputs=[]
+ )
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [
+ 1,
+ 2,
+ ]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_audio)
+ patch_float(model.encode_text)
+ model.float()
+
+ model.audio_branch.audio_length = model.audio_cfg.audio_length
+ return model
diff --git a/audioldm2/clap/open_clip/pann_model.py b/audioldm2/clap/open_clip/pann_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..e9fab8e03cdca370c141a9e321e98d256e79fb27
--- /dev/null
+++ b/audioldm2/clap/open_clip/pann_model.py
@@ -0,0 +1,697 @@
+# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
+# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn
+# Some layers are re-designed for CLAP
+import os
+
+os.environ["NUMBA_CACHE_DIR"] = "/tmp/"
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchlibrosa.stft import Spectrogram, LogmelFilterBank
+from torchlibrosa.augmentation import SpecAugmentation
+
+from .utils import do_mixup, interpolate
+from .feature_fusion import iAFF, AFF, DAF
+
+
+def init_layer(layer):
+ """Initialize a Linear or Convolutional layer."""
+ nn.init.xavier_uniform_(layer.weight)
+
+ if hasattr(layer, "bias"):
+ if layer.bias is not None:
+ layer.bias.data.fill_(0.0)
+
+
+def init_bn(bn):
+ """Initialize a Batchnorm layer."""
+ bn.bias.data.fill_(0.0)
+ bn.weight.data.fill_(1.0)
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(ConvBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=False,
+ )
+
+ self.conv2 = nn.Conv2d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=False,
+ )
+
+ self.bn1 = nn.BatchNorm2d(out_channels)
+ self.bn2 = nn.BatchNorm2d(out_channels)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_layer(self.conv1)
+ init_layer(self.conv2)
+ init_bn(self.bn1)
+ init_bn(self.bn2)
+
+ def forward(self, input, pool_size=(2, 2), pool_type="avg"):
+ x = input
+ x = F.relu_(self.bn1(self.conv1(x)))
+ x = F.relu_(self.bn2(self.conv2(x)))
+ if pool_type == "max":
+ x = F.max_pool2d(x, kernel_size=pool_size)
+ elif pool_type == "avg":
+ x = F.avg_pool2d(x, kernel_size=pool_size)
+ elif pool_type == "avg+max":
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
+ x = x1 + x2
+ else:
+ raise Exception("Incorrect argument!")
+
+ return x
+
+
+class ConvBlock5x5(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(ConvBlock5x5, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(5, 5),
+ stride=(1, 1),
+ padding=(2, 2),
+ bias=False,
+ )
+
+ self.bn1 = nn.BatchNorm2d(out_channels)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_layer(self.conv1)
+ init_bn(self.bn1)
+
+ def forward(self, input, pool_size=(2, 2), pool_type="avg"):
+ x = input
+ x = F.relu_(self.bn1(self.conv1(x)))
+ if pool_type == "max":
+ x = F.max_pool2d(x, kernel_size=pool_size)
+ elif pool_type == "avg":
+ x = F.avg_pool2d(x, kernel_size=pool_size)
+ elif pool_type == "avg+max":
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
+ x = x1 + x2
+ else:
+ raise Exception("Incorrect argument!")
+
+ return x
+
+
+class AttBlock(nn.Module):
+ def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
+ super(AttBlock, self).__init__()
+
+ self.activation = activation
+ self.temperature = temperature
+ self.att = nn.Conv1d(
+ in_channels=n_in,
+ out_channels=n_out,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ )
+ self.cla = nn.Conv1d(
+ in_channels=n_in,
+ out_channels=n_out,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ )
+
+ self.bn_att = nn.BatchNorm1d(n_out)
+ self.init_weights()
+
+ def init_weights(self):
+ init_layer(self.att)
+ init_layer(self.cla)
+ init_bn(self.bn_att)
+
+ def forward(self, x):
+ # x: (n_samples, n_in, n_time)
+ norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
+ cla = self.nonlinear_transform(self.cla(x))
+ x = torch.sum(norm_att * cla, dim=2)
+ return x, norm_att, cla
+
+ def nonlinear_transform(self, x):
+ if self.activation == "linear":
+ return x
+ elif self.activation == "sigmoid":
+ return torch.sigmoid(x)
+
+
+class Cnn14(nn.Module):
+ def __init__(
+ self,
+ sample_rate,
+ window_size,
+ hop_size,
+ mel_bins,
+ fmin,
+ fmax,
+ classes_num,
+ enable_fusion=False,
+ fusion_type="None",
+ ):
+ super(Cnn14, self).__init__()
+
+ window = "hann"
+ center = True
+ pad_mode = "reflect"
+ ref = 1.0
+ amin = 1e-10
+ top_db = None
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ # Spectrogram extractor
+ self.spectrogram_extractor = Spectrogram(
+ n_fft=window_size,
+ hop_length=hop_size,
+ win_length=window_size,
+ window=window,
+ center=center,
+ pad_mode=pad_mode,
+ freeze_parameters=True,
+ )
+
+ # Logmel feature extractor
+ self.logmel_extractor = LogmelFilterBank(
+ sr=sample_rate,
+ n_fft=window_size,
+ n_mels=mel_bins,
+ fmin=fmin,
+ fmax=fmax,
+ ref=ref,
+ amin=amin,
+ top_db=top_db,
+ freeze_parameters=True,
+ )
+
+ # Spec augmenter
+ self.spec_augmenter = SpecAugmentation(
+ time_drop_width=64,
+ time_stripes_num=2,
+ freq_drop_width=8,
+ freq_stripes_num=2,
+ )
+
+ self.bn0 = nn.BatchNorm2d(64)
+
+ if (self.enable_fusion) and (self.fusion_type == "channel_map"):
+ self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
+ else:
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
+
+ self.fc1 = nn.Linear(2048, 2048, bias=True)
+ self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
+
+ if (self.enable_fusion) and (
+ self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
+ ):
+ self.mel_conv1d = nn.Sequential(
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
+ nn.BatchNorm1d(64), # No Relu
+ )
+ if self.fusion_type == "daf_1d":
+ self.fusion_model = DAF()
+ elif self.fusion_type == "aff_1d":
+ self.fusion_model = AFF(channels=64, type="1D")
+ elif self.fusion_type == "iaff_1d":
+ self.fusion_model = iAFF(channels=64, type="1D")
+
+ if (self.enable_fusion) and (
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
+ ):
+ self.mel_conv2d = nn.Sequential(
+ nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
+ nn.BatchNorm2d(64),
+ nn.ReLU(inplace=True),
+ )
+
+ if self.fusion_type == "daf_2d":
+ self.fusion_model = DAF()
+ elif self.fusion_type == "aff_2d":
+ self.fusion_model = AFF(channels=64, type="2D")
+ elif self.fusion_type == "iaff_2d":
+ self.fusion_model = iAFF(channels=64, type="2D")
+ self.init_weight()
+
+ def init_weight(self):
+ init_bn(self.bn0)
+ init_layer(self.fc1)
+ init_layer(self.fc_audioset)
+
+ def forward(self, input, mixup_lambda=None, device=None):
+ """
+ Input: (batch_size, data_length)"""
+
+ if self.enable_fusion and input["longer"].sum() == 0:
+ # if no audio is longer than 10s, then randomly select one audio to be longer
+ input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True
+
+ if not self.enable_fusion:
+ x = self.spectrogram_extractor(
+ input["waveform"].to(device=device, non_blocking=True)
+ ) # (batch_size, 1, time_steps, freq_bins)
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
+
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+ else:
+ longer_list = input["longer"].to(device=device, non_blocking=True)
+ x = input["mel_fusion"].to(device=device, non_blocking=True)
+ longer_list_idx = torch.where(longer_list)[0]
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+ if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
+ new_x = x[:, 0:1, :, :].clone().contiguous()
+ # local processing
+ if len(longer_list_idx) > 0:
+ fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
+ FB, FC, FT, FF = fusion_x_local.size()
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
+ fusion_x_local = torch.permute(
+ fusion_x_local, (0, 2, 1)
+ ).contiguous()
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
+ fusion_x_local = fusion_x_local.view(
+ FB, FC, FF, fusion_x_local.size(-1)
+ )
+ fusion_x_local = (
+ torch.permute(fusion_x_local, (0, 2, 1, 3))
+ .contiguous()
+ .flatten(2)
+ )
+ if fusion_x_local.size(-1) < FT:
+ fusion_x_local = torch.cat(
+ [
+ fusion_x_local,
+ torch.zeros(
+ (FB, FF, FT - fusion_x_local.size(-1)),
+ device=device,
+ ),
+ ],
+ dim=-1,
+ )
+ else:
+ fusion_x_local = fusion_x_local[:, :, :FT]
+ # 1D fusion
+ new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
+ new_x[longer_list_idx] = self.fusion_model(
+ new_x[longer_list_idx], fusion_x_local
+ )
+ x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
+ else:
+ x = new_x
+ elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
+ x = x # no change
+
+ if self.training:
+ x = self.spec_augmenter(x)
+ # Mixup on spectrogram
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+ if (self.enable_fusion) and (
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
+ ):
+ global_x = x[:, 0:1, :, :]
+
+ # global processing
+ B, C, H, W = global_x.shape
+ global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
+ if len(longer_list_idx) > 0:
+ local_x = x[longer_list_idx, 1:, :, :].contiguous()
+ TH = global_x.size(-2)
+ # local processing
+ B, C, H, W = local_x.shape
+ local_x = local_x.view(B * C, 1, H, W)
+ local_x = self.mel_conv2d(local_x)
+ local_x = local_x.view(
+ B, C, local_x.size(1), local_x.size(2), local_x.size(3)
+ )
+ local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
+ TB, TC, _, TW = local_x.size()
+ if local_x.size(-2) < TH:
+ local_x = torch.cat(
+ [
+ local_x,
+ torch.zeros(
+ (TB, TC, TH - local_x.size(-2), TW),
+ device=global_x.device,
+ ),
+ ],
+ dim=-2,
+ )
+ else:
+ local_x = local_x[:, :, :TH, :]
+
+ global_x[longer_list_idx] = self.fusion_model(
+ global_x[longer_list_idx], local_x
+ )
+ x = global_x
+ else:
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = torch.mean(x, dim=3)
+
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x = latent_x1 + latent_x2
+ latent_x = latent_x.transpose(1, 2)
+ latent_x = F.relu_(self.fc1(latent_x))
+ latent_output = interpolate(latent_x, 32)
+
+ (x1, _) = torch.max(x, dim=2)
+ x2 = torch.mean(x, dim=2)
+ x = x1 + x2
+ x = F.dropout(x, p=0.5, training=self.training)
+ x = F.relu_(self.fc1(x))
+ embedding = F.dropout(x, p=0.5, training=self.training)
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
+
+ output_dict = {
+ "clipwise_output": clipwise_output,
+ "embedding": embedding,
+ "fine_grained_embedding": latent_output,
+ }
+ return output_dict
+
+
+class Cnn6(nn.Module):
+ def __init__(
+ self,
+ sample_rate,
+ window_size,
+ hop_size,
+ mel_bins,
+ fmin,
+ fmax,
+ classes_num,
+ enable_fusion=False,
+ fusion_type="None",
+ ):
+ super(Cnn6, self).__init__()
+
+ window = "hann"
+ center = True
+ pad_mode = "reflect"
+ ref = 1.0
+ amin = 1e-10
+ top_db = None
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ # Spectrogram extractor
+ self.spectrogram_extractor = Spectrogram(
+ n_fft=window_size,
+ hop_length=hop_size,
+ win_length=window_size,
+ window=window,
+ center=center,
+ pad_mode=pad_mode,
+ freeze_parameters=True,
+ )
+
+ # Logmel feature extractor
+ self.logmel_extractor = LogmelFilterBank(
+ sr=sample_rate,
+ n_fft=window_size,
+ n_mels=mel_bins,
+ fmin=fmin,
+ fmax=fmax,
+ ref=ref,
+ amin=amin,
+ top_db=top_db,
+ freeze_parameters=True,
+ )
+
+ # Spec augmenter
+ self.spec_augmenter = SpecAugmentation(
+ time_drop_width=64,
+ time_stripes_num=2,
+ freq_drop_width=8,
+ freq_stripes_num=2,
+ )
+
+ self.bn0 = nn.BatchNorm2d(64)
+
+ self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
+ self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
+ self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
+ self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
+
+ self.fc1 = nn.Linear(512, 512, bias=True)
+ self.fc_audioset = nn.Linear(512, classes_num, bias=True)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_bn(self.bn0)
+ init_layer(self.fc1)
+ init_layer(self.fc_audioset)
+
+ def forward(self, input, mixup_lambda=None, device=None):
+ """
+ Input: (batch_size, data_length)"""
+
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
+
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+
+ if self.training:
+ x = self.spec_augmenter(x)
+
+ # Mixup on spectrogram
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = torch.mean(x, dim=3)
+
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x = latent_x1 + latent_x2
+ latent_x = latent_x.transpose(1, 2)
+ latent_x = F.relu_(self.fc1(latent_x))
+ latent_output = interpolate(latent_x, 16)
+
+ (x1, _) = torch.max(x, dim=2)
+ x2 = torch.mean(x, dim=2)
+ x = x1 + x2
+ x = F.dropout(x, p=0.5, training=self.training)
+ x = F.relu_(self.fc1(x))
+ embedding = F.dropout(x, p=0.5, training=self.training)
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
+
+ output_dict = {
+ "clipwise_output": clipwise_output,
+ "embedding": embedding,
+ "fine_grained_embedding": latent_output,
+ }
+
+ return output_dict
+
+
+class Cnn10(nn.Module):
+ def __init__(
+ self,
+ sample_rate,
+ window_size,
+ hop_size,
+ mel_bins,
+ fmin,
+ fmax,
+ classes_num,
+ enable_fusion=False,
+ fusion_type="None",
+ ):
+ super(Cnn10, self).__init__()
+
+ window = "hann"
+ center = True
+ pad_mode = "reflect"
+ ref = 1.0
+ amin = 1e-10
+ top_db = None
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ # Spectrogram extractor
+ self.spectrogram_extractor = Spectrogram(
+ n_fft=window_size,
+ hop_length=hop_size,
+ win_length=window_size,
+ window=window,
+ center=center,
+ pad_mode=pad_mode,
+ freeze_parameters=True,
+ )
+
+ # Logmel feature extractor
+ self.logmel_extractor = LogmelFilterBank(
+ sr=sample_rate,
+ n_fft=window_size,
+ n_mels=mel_bins,
+ fmin=fmin,
+ fmax=fmax,
+ ref=ref,
+ amin=amin,
+ top_db=top_db,
+ freeze_parameters=True,
+ )
+
+ # Spec augmenter
+ self.spec_augmenter = SpecAugmentation(
+ time_drop_width=64,
+ time_stripes_num=2,
+ freq_drop_width=8,
+ freq_stripes_num=2,
+ )
+
+ self.bn0 = nn.BatchNorm2d(64)
+
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+
+ self.fc1 = nn.Linear(1024, 1024, bias=True)
+ self.fc_audioset = nn.Linear(1024, classes_num, bias=True)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_bn(self.bn0)
+ init_layer(self.fc1)
+ init_layer(self.fc_audioset)
+
+ def forward(self, input, mixup_lambda=None, device=None):
+ """
+ Input: (batch_size, data_length)"""
+
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
+
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+
+ if self.training:
+ x = self.spec_augmenter(x)
+
+ # Mixup on spectrogram
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = torch.mean(x, dim=3)
+
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x = latent_x1 + latent_x2
+ latent_x = latent_x.transpose(1, 2)
+ latent_x = F.relu_(self.fc1(latent_x))
+ latent_output = interpolate(latent_x, 32)
+
+ (x1, _) = torch.max(x, dim=2)
+ x2 = torch.mean(x, dim=2)
+ x = x1 + x2
+ x = F.dropout(x, p=0.5, training=self.training)
+ x = F.relu_(self.fc1(x))
+ embedding = F.dropout(x, p=0.5, training=self.training)
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
+
+ output_dict = {
+ "clipwise_output": clipwise_output,
+ "embedding": embedding,
+ "fine_grained_embedding": latent_output,
+ }
+
+ return output_dict
+
+
+def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
+ try:
+ ModelProto = eval(audio_cfg.model_name)
+ model = ModelProto(
+ sample_rate=audio_cfg.sample_rate,
+ window_size=audio_cfg.window_size,
+ hop_size=audio_cfg.hop_size,
+ mel_bins=audio_cfg.mel_bins,
+ fmin=audio_cfg.fmin,
+ fmax=audio_cfg.fmax,
+ classes_num=audio_cfg.class_num,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+ return model
+ except:
+ raise RuntimeError(
+ f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
+ )
diff --git a/audioldm2/clap/open_clip/pretrained.py b/audioldm2/clap/open_clip/pretrained.py
new file mode 100755
index 0000000000000000000000000000000000000000..e211d8b5b59320a599e62605f1dee6199f317253
--- /dev/null
+++ b/audioldm2/clap/open_clip/pretrained.py
@@ -0,0 +1,167 @@
+import hashlib
+import os
+import urllib
+import warnings
+
+from tqdm import tqdm
+
+_RN50 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
+ cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
+)
+
+_RN50_quickgelu = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
+ cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
+)
+
+_RN101 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
+)
+
+_RN101_quickgelu = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
+)
+
+_RN50x4 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+)
+
+_RN50x16 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+)
+
+_RN50x64 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+)
+
+_VITB32 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+ laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
+ laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
+ laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
+)
+
+_VITB32_quickgelu = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+ laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
+ laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
+ laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
+)
+
+_VITB16 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+)
+
+_VITL14 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+)
+
+_PRETRAINED = {
+ "RN50": _RN50,
+ "RN50-quickgelu": _RN50_quickgelu,
+ "RN101": _RN101,
+ "RN101-quickgelu": _RN101_quickgelu,
+ "RN50x4": _RN50x4,
+ "RN50x16": _RN50x16,
+ "ViT-B-32": _VITB32,
+ "ViT-B-32-quickgelu": _VITB32_quickgelu,
+ "ViT-B-16": _VITB16,
+ "ViT-L-14": _VITL14,
+}
+
+
+def list_pretrained(as_str: bool = False):
+ """returns list of pretrained models
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
+ """
+ return [
+ ":".join([k, t]) if as_str else (k, t)
+ for k in _PRETRAINED.keys()
+ for t in _PRETRAINED[k].keys()
+ ]
+
+
+def list_pretrained_tag_models(tag: str):
+ """return all models having the specified pretrain tag"""
+ models = []
+ for k in _PRETRAINED.keys():
+ if tag in _PRETRAINED[k]:
+ models.append(k)
+ return models
+
+
+def list_pretrained_model_tags(model: str):
+ """return all pretrain tags for the specified model architecture"""
+ tags = []
+ if model in _PRETRAINED:
+ tags.extend(_PRETRAINED[model].keys())
+ return tags
+
+
+def get_pretrained_url(model: str, tag: str):
+ if model not in _PRETRAINED:
+ return ""
+ model_pretrained = _PRETRAINED[model]
+ if tag not in model_pretrained:
+ return ""
+ return model_pretrained[tag]
+
+
+def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ if "openaipublic" in url:
+ expected_sha256 = url.split("/")[-2]
+ else:
+ expected_sha256 = ""
+
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if expected_sha256:
+ if (
+ hashlib.sha256(open(download_target, "rb").read()).hexdigest()
+ == expected_sha256
+ ):
+ return download_target
+ else:
+ warnings.warn(
+ f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
+ )
+ else:
+ return download_target
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(
+ total=int(source.info().get("Content-Length")),
+ ncols=80,
+ unit="iB",
+ unit_scale=True,
+ ) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if (
+ expected_sha256
+ and hashlib.sha256(open(download_target, "rb").read()).hexdigest()
+ != expected_sha256
+ ):
+ raise RuntimeError(
+ f"Model has been downloaded but the SHA256 checksum does not not match"
+ )
+
+ return download_target
diff --git a/audioldm2/clap/open_clip/timm_model.py b/audioldm2/clap/open_clip/timm_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..b8486b9e62580bb65f0f50a0a7000890cb7ee42d
--- /dev/null
+++ b/audioldm2/clap/open_clip/timm_model.py
@@ -0,0 +1,112 @@
+""" timm model adapter
+
+Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
+"""
+from collections import OrderedDict
+
+import torch.nn as nn
+
+try:
+ import timm
+ from timm.models.layers import Mlp, to_2tuple
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
+ from timm.models.layers.attention_pool2d import (
+ AttentionPool2d as AbsAttentionPool2d,
+ )
+except ImportError:
+ timm = None
+
+from .utils import freeze_batch_norm_2d
+
+
+class TimmModel(nn.Module):
+ """timm model adapter
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
+ """
+
+ def __init__(
+ self,
+ model_name,
+ embed_dim,
+ image_size=224,
+ pool="avg",
+ proj="linear",
+ drop=0.0,
+ pretrained=False,
+ ):
+ super().__init__()
+ if timm is None:
+ raise RuntimeError("Please `pip install timm` to use timm models.")
+
+ self.image_size = to_2tuple(image_size)
+ self.trunk = timm.create_model(model_name, pretrained=pretrained)
+ feat_size = self.trunk.default_cfg.get("pool_size", None)
+ feature_ndim = 1 if not feat_size else 2
+ if pool in ("abs_attn", "rot_attn"):
+ assert feature_ndim == 2
+ # if attn pooling used, remove both classifier and default pool
+ self.trunk.reset_classifier(0, global_pool="")
+ else:
+ # reset global pool if pool config set, otherwise leave as network default
+ reset_kwargs = dict(global_pool=pool) if pool else {}
+ self.trunk.reset_classifier(0, **reset_kwargs)
+ prev_chs = self.trunk.num_features
+
+ head_layers = OrderedDict()
+ if pool == "abs_attn":
+ head_layers["pool"] = AbsAttentionPool2d(
+ prev_chs, feat_size=feat_size, out_features=embed_dim
+ )
+ prev_chs = embed_dim
+ elif pool == "rot_attn":
+ head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
+ prev_chs = embed_dim
+ else:
+ assert proj, "projection layer needed if non-attention pooling is used."
+
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
+ if proj == "linear":
+ head_layers["drop"] = nn.Dropout(drop)
+ head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
+ elif proj == "mlp":
+ head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
+
+ self.head = nn.Sequential(head_layers)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ """lock modules
+ Args:
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
+ """
+ if not unlocked_groups:
+ # lock full model
+ for param in self.trunk.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self.trunk)
+ else:
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
+ try:
+ # FIXME import here until API stable and in an official release
+ from timm.models.helpers import group_parameters, group_modules
+ except ImportError:
+ raise RuntimeError(
+ "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
+ )
+ matcher = self.trunk.group_matcher()
+ gparams = group_parameters(self.trunk, matcher)
+ max_layer_id = max(gparams.keys())
+ max_layer_id = max_layer_id - unlocked_groups
+ for group_idx in range(max_layer_id + 1):
+ group = gparams[group_idx]
+ for param in group:
+ self.trunk.get_parameter(param).requires_grad = False
+ if freeze_bn_stats:
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
+ freeze_batch_norm_2d(self.trunk, gmodules)
+
+ def forward(self, x):
+ x = self.trunk(x)
+ x = self.head(x)
+ return x
diff --git a/audioldm2/clap/open_clip/tokenizer.py b/audioldm2/clap/open_clip/tokenizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..ee4d28450ec5dd12a79daf38cf3088e9e73c2cd5
--- /dev/null
+++ b/audioldm2/clap/open_clip/tokenizer.py
@@ -0,0 +1,197 @@
+""" CLIP tokenizer
+
+Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+import gzip
+import html
+import os
+from functools import lru_cache
+from typing import Union, List
+
+import ftfy
+import regex as re
+import torch
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
+ )
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1))
+ + list(range(ord("¡"), ord("¬") + 1))
+ + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
+ merges = merges[1 : 49152 - 256 - 2 + 1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v + "" for v in vocab]
+ for merge in merges:
+ vocab.append("".join(merge))
+ if not special_tokens:
+ special_tokens = ["", ""]
+ else:
+ special_tokens = ["", ""] + special_tokens
+ vocab.extend(special_tokens)
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {t: t for t in special_tokens}
+ special = "|".join(special_tokens)
+ self.pat = re.compile(
+ special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+ re.IGNORECASE,
+ )
+
+ self.vocab_size = len(self.encoder)
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + (token[-1] + "",)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token + ""
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+ bpe_tokens.extend(
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
+ )
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = "".join([self.decoder[token] for token in tokens])
+ text = (
+ bytearray([self.byte_decoder[c] for c in text])
+ .decode("utf-8", errors="replace")
+ .replace("", " ")
+ )
+ return text
+
+
+_tokenizer = SimpleTokenizer()
+
+
+def tokenize(
+ texts: Union[str, List[str]], context_length: int = 77
+) -> torch.LongTensor:
+ """
+ Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ----------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = _tokenizer.encoder[""]
+ eot_token = _tokenizer.encoder[""]
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ tokens = tokens[:context_length] # Truncate
+ result[i, : len(tokens)] = torch.tensor(tokens)
+
+ return result
diff --git a/audioldm2/clap/open_clip/transform.py b/audioldm2/clap/open_clip/transform.py
new file mode 100755
index 0000000000000000000000000000000000000000..77aaa722c4a5544ac50de6df35d3e922f63b111d
--- /dev/null
+++ b/audioldm2/clap/open_clip/transform.py
@@ -0,0 +1,45 @@
+from torchvision.transforms import (
+ Normalize,
+ Compose,
+ RandomResizedCrop,
+ InterpolationMode,
+ ToTensor,
+ Resize,
+ CenterCrop,
+)
+
+
+def _convert_to_rgb(image):
+ return image.convert("RGB")
+
+
+def image_transform(
+ image_size: int,
+ is_train: bool,
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711),
+):
+ normalize = Normalize(mean=mean, std=std)
+ if is_train:
+ return Compose(
+ [
+ RandomResizedCrop(
+ image_size,
+ scale=(0.9, 1.0),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ _convert_to_rgb,
+ ToTensor(),
+ normalize,
+ ]
+ )
+ else:
+ return Compose(
+ [
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
+ CenterCrop(image_size),
+ _convert_to_rgb,
+ ToTensor(),
+ normalize,
+ ]
+ )
diff --git a/audioldm2/clap/open_clip/utils.py b/audioldm2/clap/open_clip/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..77875569ff4aff81bf9545ce6ec58e0326d49d0c
--- /dev/null
+++ b/audioldm2/clap/open_clip/utils.py
@@ -0,0 +1,356 @@
+import numpy as np
+import torch
+from torch import nn as nn
+from torchvision.ops.misc import FrozenBatchNorm2d
+import logging
+import h5py
+from tqdm import tqdm
+import random
+import json
+import os
+import pathlib
+
+# TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later.
+dataset_split = {
+ "audiocaps": ["train", "valid", "test"],
+ "audioset": ["balanced_train", "unbalanced_train", "eval"],
+ "BBCSoundEffects": ["train", "test"],
+ "Clotho": ["train", "test", "valid"],
+ "free_to_use_sounds": ["train", "test"],
+ "paramount_motion": ["train", "test"],
+ "sonniss_game_effects": ["train", "test"],
+ "wesoundeffects": ["train", "test"],
+ "MACS": ["train", "test"],
+ "freesound": ["train", "test"],
+ "FSD50K": ["train", "test", "valid"],
+ "fsd50k_class_label": ["train", "test", "valid"],
+ "esc50": ["train", "test"],
+ "audiostock": ["train", "test"],
+ "freesound_no_overlap_noesc50": ["train", "test"],
+ "epidemic_sound_effects": ["train", "test"],
+ "VGGSound": ["train", "test"],
+ "urbansound8k_class_label": ["train", "test"],
+ "audioset_t5": ["balanced_train", "unbalanced_train", "eval"],
+ "epidemic_sound_effects_t5": ["train", "test"],
+ "WavText5K": ["train", "test"],
+ "esc50_no_overlap": ["train", "test"],
+ "usd8k_no_overlap": ["train", "test"],
+ "fsd50k_200_class_label": ["train", "test", "valid"],
+}
+
+
+def freeze_batch_norm_2d(module, module_match={}, name=""):
+ """
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
+
+ Args:
+ module (torch.nn.Module): Any PyTorch module.
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
+ name (str): Full module name (prefix)
+
+ Returns:
+ torch.nn.Module: Resulting module
+
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
+ """
+ res = module
+ is_match = True
+ if module_match:
+ is_match = name in module_match
+ if is_match and isinstance(
+ module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
+ ):
+ res = FrozenBatchNorm2d(module.num_features)
+ res.num_features = module.num_features
+ res.affine = module.affine
+ if module.affine:
+ res.weight.data = module.weight.data.clone().detach()
+ res.bias.data = module.bias.data.clone().detach()
+ res.running_mean.data = module.running_mean.data
+ res.running_var.data = module.running_var.data
+ res.eps = module.eps
+ else:
+ for child_name, child in module.named_children():
+ full_child_name = ".".join([name, child_name]) if name else child_name
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
+ if new_child is not child:
+ res.add_module(child_name, new_child)
+ return res
+
+
+def exist(dataset_name, dataset_type):
+ """
+ Check if dataset exists
+ """
+ if dataset_type in dataset_split[dataset_name]:
+ return True
+ else:
+ return False
+
+
+def get_tar_path_from_dataset_name(
+ dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None
+):
+ """
+ Get tar path from dataset name and type
+ """
+ output = []
+ for n in dataset_names:
+ if full_dataset is not None and n in full_dataset:
+ current_dataset_types = dataset_split[n]
+ else:
+ current_dataset_types = dataset_types
+ for s in current_dataset_types:
+ tmp = []
+ if islocal:
+ sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json"
+ if not os.path.exists(sizefilepath_):
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
+ else:
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
+ if not os.path.exists(sizefilepath_):
+ continue
+ sizes = json.load(open(sizefilepath_, "r"))
+ for k in sizes.keys():
+ if islocal:
+ tmp.append(f"{dataset_path}/{n}/{s}/{k}")
+ else:
+ tmp.append(
+ f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
+ )
+ if proportion != 1:
+ tmp = random.sample(tmp, int(proportion * len(tmp)))
+ output.append(tmp)
+ return sum(output, [])
+
+
+def get_tar_path_from_txts(txt_path, islocal, proportion=1):
+ """
+ Get tar path from txt path
+ """
+ if isinstance(txt_path, (list, tuple)):
+ return sum(
+ [
+ get_tar_path_from_txts(
+ txt_path[i], islocal=islocal, proportion=proportion
+ )
+ for i in range(len(txt_path))
+ ],
+ [],
+ )
+ if isinstance(txt_path, str):
+ with open(txt_path) as f:
+ lines = f.readlines()
+ if islocal:
+ lines = [
+ lines[i]
+ .split("\n")[0]
+ .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/")
+ for i in range(len(lines))
+ ]
+ else:
+ lines = [
+ lines[i].split("\n")[0].replace(".tar", ".tar -")
+ for i in range(len(lines))
+ ]
+ if proportion != 1:
+ print("Sampling tars with proportion of {}".format(proportion))
+ lines = random.sample(lines, int(proportion * len(lines)))
+ return lines
+
+
+def get_mix_lambda(mixup_alpha, batch_size):
+ mixup_lambdas = [
+ np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)
+ ]
+ return np.array(mixup_lambdas).astype(np.float32)
+
+
+def do_mixup(x, mixup_lambda):
+ """
+ Args:
+ x: (batch_size , ...)
+ mixup_lambda: (batch_size,)
+ Returns:
+ out: (batch_size, ...)
+ """
+ out = (
+ x.transpose(0, -1) * mixup_lambda
+ + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
+ ).transpose(0, -1)
+ return out
+
+
+def interpolate(x, ratio):
+ """Interpolate data in time domain. This is used to compensate the
+ resolution reduction in downsampling of a CNN.
+
+ Args:
+ x: (batch_size, time_steps, classes_num)
+ ratio: int, ratio to interpolate
+ Returns:
+ upsampled: (batch_size, time_steps * ratio, classes_num)
+ """
+ (batch_size, time_steps, classes_num) = x.shape
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
+ return upsampled
+
+
+def pad_framewise_output(framewise_output, frames_num):
+ """Pad framewise_output to the same length as input frames. The pad value
+ is the same as the value of the last frame.
+ Args:
+ framewise_output: (batch_size, frames_num, classes_num)
+ frames_num: int, number of frames to pad
+ Outputs:
+ output: (batch_size, frames_num, classes_num)
+ """
+ pad = framewise_output[:, -1:, :].repeat(
+ 1, frames_num - framewise_output.shape[1], 1
+ )
+ """tensor for padding"""
+
+ output = torch.cat((framewise_output, pad), dim=1)
+ """(batch_size, frames_num, classes_num)"""
+
+
+def process_ipc(index_path, classes_num, filename):
+ # load data
+ logging.info("Load Data...............")
+ ipc = [[] for _ in range(classes_num)]
+ with h5py.File(index_path, "r") as f:
+ for i in tqdm(range(len(f["target"]))):
+ t_class = np.where(f["target"][i])[0]
+ for t in t_class:
+ ipc[t].append(i)
+ print(ipc)
+ np.save(filename, ipc)
+ logging.info("Load Data Succeed...............")
+
+
+def save_to_dict(s, o_={}):
+ sp = s.split(": ")
+ o_.update({sp[0]: float(sp[1])})
+ return o_
+
+
+def get_data_from_log(txt_path):
+ """
+ Output dictionary from out.txt log file
+ """
+ with open(txt_path) as f:
+ lines = f.readlines()
+ val_data = {}
+ train_data = {}
+ train_losses = []
+ train_losses_epoch = []
+ for i in range(len(lines)):
+ if "| INFO |" in lines[i]:
+ if "Eval Epoch" in lines[i]:
+ if "val_loss" in lines[i]:
+ # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", ""))
+ line = lines[i].split("Eval Epoch: ")[-1]
+ num_epoch = int(line.split(" ")[0].split(" ")[0])
+ d = {
+ line.split(" ")[0]
+ .split(" ")[1]
+ .replace(":", ""): float(line.split(" ")[0].split(" ")[-1])
+ }
+ for i in range(1, len(line.split(" "))):
+ d = save_to_dict(line.split(" ")[i], d)
+ val_data[num_epoch] = d
+ elif "Train Epoch" in lines[i]:
+ num_epoch = int(lines[i].split("Train Epoch: ")[1][0])
+ loss = float(lines[i].split("Loss: ")[-1].split(" (")[0])
+ train_losses.append(loss)
+ train_losses_epoch.append(num_epoch)
+ for i in range(len(train_losses)):
+ train_data[i] = {
+ "num_epoch": train_losses_epoch[i],
+ "train_loss": train_losses[i],
+ }
+ return train_data, val_data
+
+
+def save_p(obj, filename):
+ import pickle
+
+ try:
+ from deepdiff import DeepDiff
+ except:
+ os.system("pip install deepdiff")
+ from deepdiff import DeepDiff
+ with open(filename, "wb") as file:
+ pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol
+ with open(filename, "rb") as file:
+ z = pickle.load(file)
+ assert (
+ DeepDiff(obj, z, ignore_string_case=True) == {}
+ ), "there is something wrong with the saving process"
+ return
+
+
+def load_p(filename):
+ import pickle
+
+ with open(filename, "rb") as file:
+ z = pickle.load(file)
+ return z
+
+
+def save_json(data, name="data.json"):
+ import json
+
+ with open(name, "w") as fp:
+ json.dump(data, fp)
+ return
+
+
+def load_json(name):
+ import json
+
+ with open(name, "r") as fp:
+ data = json.load(fp)
+ return data
+
+
+def load_class_label(path):
+ # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
+ # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
+ out = None
+ if path is not None:
+ if pathlib.Path(path).suffix in [".pkl", ".pickle"]:
+ out = load_p(path)
+ elif pathlib.Path(path).suffix in [".json", ".txt"]:
+ out = load_json(path)
+ elif pathlib.Path(path).suffix in [".npy", ".npz"]:
+ out = np.load(path)
+ elif pathlib.Path(path).suffix in [".csv"]:
+ import pandas as pd
+
+ out = pd.read_csv(path)
+ return out
+ # if out is None:
+ # return None
+ # else:
+ # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False)
+ # val = Array('i', out.values(), lock=False)
+ # return (key, val)
+
+
+from torch import optim
+
+
+def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
+ if optimizer_name.lower() == "adamw":
+ optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps)
+ elif optimizer_name.lower() == "sgd":
+ optimizer = optim.SGD(params, lr=lr, momentum=momentum)
+ elif optimizer_name.lower() == "adam":
+ optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps)
+ else:
+ raise ValueError("optimizer name is not correct")
+ return optimizer
diff --git a/audioldm2/clap/training/__init__.py b/audioldm2/clap/training/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/clap/training/audioset_textmap.npy b/audioldm2/clap/training/audioset_textmap.npy
new file mode 100755
index 0000000000000000000000000000000000000000..3da4c92d3819aaec11e5f576464a9973a6df811b
--- /dev/null
+++ b/audioldm2/clap/training/audioset_textmap.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b
+size 84448
diff --git a/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz b/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/audioldm2/clap/training/data.py b/audioldm2/clap/training/data.py
new file mode 100755
index 0000000000000000000000000000000000000000..ae01406c63a9b1c678151f67dacd7ea192cb84f2
--- /dev/null
+++ b/audioldm2/clap/training/data.py
@@ -0,0 +1,865 @@
+import json
+import logging
+import os
+import random
+import h5py
+from dataclasses import dataclass
+import numpy as np
+import pandas as pd
+import torch
+import torchvision.datasets as datasets
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
+from torch.utils.data.distributed import DistributedSampler
+import soundfile as sf
+import io
+from pathlib import Path
+# import wget
+
+from audioldm2.clap.open_clip.utils import get_tar_path_from_dataset_name
+from audioldm2.clap.open_clip.utils import load_class_label
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+try:
+ import torchaudio
+except ImportError:
+ torchaudio = None
+
+from audioldm2.clap.open_clip import tokenize
+
+
+def tokenizer(text):
+ return tokenize(text).squeeze(0)
+
+
+from transformers import RobertaTokenizer
+
+tokenize = RobertaTokenizer.from_pretrained("roberta-base")
+
+
+def tokenizer(text):
+ result = tokenize(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ )
+ return {k: v.squeeze(0) for k, v in result.items()}
+
+
+# initizlied the audioset map
+_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy")
+_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True)
+
+
+def int16_to_float32(x):
+ return (x / 32767.0).astype(np.float32)
+
+
+def float32_to_int16(x):
+ x = np.clip(x, a_min=-1.0, a_max=1.0)
+ return (x * 32767.0).astype(np.int16)
+
+
+# For Toy Dataset
+class ToyDataset(Dataset):
+ def __init__(self, index_path, ipc, config, eval_mode=False):
+ """Toy Dataset for testing the audioset input with text labels
+ Parameters
+ ----------
+ index_path: str
+ the link to the h5 file of each audio
+ idc: str
+ the link to the npy file, the number of samples in each class
+ config: dict
+ the audio cfg file
+ eval_model (bool): to indicate if the dataset is a testing dataset
+ """
+ self.audio_cfg = config["audio_cfg"]
+ self.text_cfg = config["text_cfg"]
+ self.fp = h5py.File(index_path, "r")
+ self.ipc = np.load(ipc, allow_pickle=True)
+ self.total_size = len(self.fp["audio_name"])
+ self.classes_num = self.audio_cfg["class_num"]
+ self.eval_mode = eval_mode
+
+ if not eval_mode:
+ self.generate_queue()
+ else:
+ self.queue = []
+ for i in range(self.total_size):
+ target = self.fp["target"][i]
+ if np.sum(target) > 0:
+ self.queue.append(i)
+ self.total_size = len(self.queue)
+ logging.info("total dataset size: %d" % (self.total_size))
+ logging.info("class num: %d" % (self.classes_num))
+
+ def time_shifting(self, x):
+ frame_num = len(x)
+ shift_len = random.randint(0, frame_num - 1)
+ new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0)
+ return new_sample
+
+ def generate_queue(self):
+ self.queue = []
+ while len(self.queue) < self.total_size:
+ class_set = [*range(self.classes_num)]
+ random.shuffle(class_set)
+ self.queue += [
+ self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set
+ ]
+ self.queue = self.queue[: self.total_size]
+
+ logging.info("queue regenerated:%s" % (self.queue[-5:]))
+
+ def crop_wav(self, x):
+ crop_size = self.audio_cfg["crop_size"]
+ crop_pos = random.randint(0, len(x) - crop_size - 1)
+ return x[crop_pos : crop_pos + crop_size]
+
+ def prompt_text(self, target):
+ events = _AUDIOSET_MAP[np.where(target > 0)]
+ event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1]
+ text = tokenize(event_text)[0]
+ return text
+
+ def __getitem__(self, index):
+ """Load waveform, text, and target of an audio clip
+
+ Parameters
+ ----------
+ index: int
+ the index number
+ Return
+ ------
+ output: dict {
+ "hdf5_path": str,
+ "index_in_hdf5": int,
+ "audio_name": str,
+ "waveform": list (audio_length,),
+ "target": list (class_num, ),
+ "text": torch.tensor (context_length,)
+ }
+ the output dictionary
+ """
+ s_index = self.queue[index]
+
+ audio_name = self.fp["audio_name"][s_index].decode()
+ # Hardcode here CHANGE
+ hdf5_path = (
+ self.fp["hdf5_path"][s_index]
+ .decode()
+ .replace(
+ "../workspace",
+ "/home/la/kechen/Research/ke_zsasp/workspace",
+ )
+ )
+ r_idx = self.fp["index_in_hdf5"][s_index]
+ target = self.fp["target"][s_index].astype(np.float32)
+ text = self.prompt_text(target)
+ with h5py.File(hdf5_path, "r") as f:
+ waveform = int16_to_float32(f["waveform"][r_idx])[
+ : self.audio_cfg["clip_samples"]
+ ]
+ assert (
+ len(waveform) == self.audio_cfg["clip_samples"]
+ ), "The sample length is not match"
+ # Time shift
+ # if (self.config.enable_time_shift) and (not self.eval_mode):
+ # waveform = self.time_shifting(waveform)
+ # # Label Enhance
+ # if (self.config.crop_size is not None) and (not self.eval_mode):
+ # waveform = self.crop_wav(waveform)
+ # # the label enhance rate is fixed 0.5
+ # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5:
+ # kidx = np.where(target)[0]
+ # for k in kidx:
+ # for add_key in self.class_map[k][1]:
+ # target[add_key] = 1.0
+ # if len(self.class_map[k][2]) > 0:
+ # add_key = random.choice(self.class_map[k][2])
+ # target[add_key] = 1.0
+
+ # missing the text input
+ mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :]
+ mel_spec = (
+ torch.cat(
+ [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0
+ )
+ .cpu()
+ .numpy()
+ )
+ longer = random.choice([True, False])
+ if longer == False:
+ mel_spec[1:, :, :] = 0.0
+ data_dict = {
+ "hdf5_path": hdf5_path,
+ "index_in_hdf5": r_idx,
+ "audio_name": audio_name,
+ "waveform": waveform,
+ "class_label": target,
+ "text": text,
+ "longer": longer,
+ "mel_fusion": mel_spec,
+ }
+ return data_dict
+
+ def __len__(self):
+ return self.total_size
+
+
+class CsvDataset(Dataset):
+ def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
+ logging.debug(f"Loading csv data from {input_filename}.")
+ df = pd.read_csv(input_filename, sep=sep)
+
+ self.images = df[img_key].tolist()
+ self.captions = df[caption_key].tolist()
+ self.transforms = transforms
+ logging.debug("Done loading data.")
+
+ def __len__(self):
+ return len(self.captions)
+
+ def __getitem__(self, idx):
+ images = self.transforms(Image.open(str(self.images[idx])))
+ texts = tokenize([str(self.captions[idx])])[0]
+ return images, texts
+
+
+@dataclass
+class DataInfo:
+ dataloader: DataLoader
+ sampler: DistributedSampler
+
+
+def preprocess_txt(text):
+ return tokenize([str(text)])[0]
+
+
+# def get_dataset_size(shards, sizefilepath_=None, is_local=True):
+# if isinstance(shards, list):
+# size_list = []
+# for s in shards:
+# size_list.append(
+# get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0]
+# )
+# else:
+# if not is_local:
+# for n in dataset_split.keys():
+# if n in shards.split("/"):
+# break
+# for s in dataset_split[n]:
+# if s in shards.split("/"):
+# break
+# sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
+# shards_list = list(braceexpand.braceexpand(shards))
+# dir_path = os.path.dirname(shards)
+# if sizefilepath_ is not None:
+# sizes = json.load(open(sizefilepath_, "r"))
+# total_size = sum(
+# [
+# int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))])
+# for shard in shards_list
+# ]
+# )
+# else:
+# sizes_filename = os.path.join(dir_path, "sizes.json")
+# len_filename = os.path.join(dir_path, "__len__")
+# if os.path.exists(sizes_filename):
+# sizes = json.load(open(sizes_filename, "r"))
+# total_size = sum(
+# [int(sizes[os.path.basename(shard)]) for shard in shards_list]
+# )
+# elif os.path.exists(len_filename):
+# # FIXME this used to be eval(open(...)) but that seemed rather unsafe
+# total_size = ast.literal_eval(open(len_filename, "r").read())
+# else:
+# raise Exception(
+# "Cannot find sizes file for dataset. Please specify the path to the file."
+# )
+# # total_size = None # num samples undefined
+# # some common dataset sizes (at time of authors last download)
+# # cc3m-train: 2905954
+# # cc12m: 10968539
+# # LAION-400m: 407332084
+# num_shards = len(shards_list)
+# if isinstance(shards, list):
+# return sum(size_list), len(shards)
+# else:
+# return total_size, num_shards
+
+
+def get_imagenet(args, preprocess_fns, split):
+ assert split in ["train", "val", "v2"]
+ is_train = split == "train"
+ preprocess_train, preprocess_val = preprocess_fns
+
+ if split == "v2":
+ from imagenetv2_pytorch import ImageNetV2Dataset
+
+ dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
+ else:
+ if is_train:
+ data_path = args.imagenet_train
+ preprocess_fn = preprocess_train
+ else:
+ data_path = args.imagenet_val
+ preprocess_fn = preprocess_val
+ assert data_path
+
+ dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
+
+ if is_train:
+ idxs = np.zeros(len(dataset.targets))
+ target_array = np.array(dataset.targets)
+ k = 50
+ for c in range(1000):
+ m = target_array == c
+ n = len(idxs[m])
+ arr = np.zeros(n)
+ arr[:k] = 1
+ np.random.shuffle(arr)
+ idxs[m] = arr
+
+ idxs = idxs.astype("int")
+ sampler = SubsetRandomSampler(np.where(idxs)[0])
+ else:
+ sampler = None
+
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ num_workers=args.workers,
+ sampler=sampler,
+ )
+
+ return DataInfo(dataloader, sampler)
+
+
+def count_samples(dataloader):
+ os.environ["WDS_EPOCH"] = "0"
+ n_elements, n_batches = 0, 0
+ for images, texts in dataloader:
+ n_batches += 1
+ n_elements += len(images)
+ assert len(images) == len(texts)
+ return n_elements, n_batches
+
+
+def filter_no_caption(sample):
+ return "txt" in sample
+
+
+def log_and_continue(exn):
+ """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
+ logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
+ return True
+
+
+_SHARD_SHUFFLE_SIZE = 2000
+_SHARD_SHUFFLE_INITIAL = 500
+_SAMPLE_SHUFFLE_SIZE = 5000
+_SAMPLE_SHUFFLE_INITIAL = 1000
+
+
+# def sample_prop(sizefile, inputs, proportion, is_local=True):
+# """
+# Sample a proportion of the data.
+# """
+# file_path_dict = {
+# os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0]
+# for i in range(len(inputs))
+# }
+# sampled_filepath_dict = {}
+# sampled_size_dict = {}
+# if not is_local:
+# if os.path.exists("sizes.json"):
+# os.remove("sizes.json")
+# wget.download(sizefile, "sizes.json")
+# sizefile = "sizes.json"
+# with open(sizefile, "r", encoding="UTF-8") as f:
+# load_dict = json.load(f)
+# L = int(len(file_path_dict) * proportion)
+# subkeys = random.sample(file_path_dict.keys(), L)
+# for k in subkeys:
+# sampled_size_dict[k] = load_dict[k]
+# sampled_filepath_dict[k] = file_path_dict[k]
+# return (
+# sum(sampled_size_dict.values()),
+# L,
+# [os.path.join(v, k) for k, v in sampled_filepath_dict.items()],
+# sampled_size_dict,
+# )
+
+
+def get_mel(audio_data, audio_cfg):
+ # mel shape: (n_mels, T)
+ mel = torchaudio.transforms.MelSpectrogram(
+ sample_rate=audio_cfg["sample_rate"],
+ n_fft=audio_cfg["window_size"],
+ win_length=audio_cfg["window_size"],
+ hop_length=audio_cfg["hop_size"],
+ center=True,
+ pad_mode="reflect",
+ power=2.0,
+ norm=None,
+ onesided=True,
+ n_mels=64,
+ f_min=audio_cfg["fmin"],
+ f_max=audio_cfg["fmax"],
+ ).to(audio_data.device)
+ mel = mel(audio_data)
+ # we use log mel spectrogram as input
+ mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
+ return mel.T # (T, n_mels)
+
+
+def get_audio_features(
+ audio_data, mel, max_len, data_truncating, data_filling, audio_cfg
+):
+ """
+ Calculate and add audio features to sample.
+ Sample: a dict containing all the data of current sample.
+ audio_data: a tensor of shape (T) containing audio data.
+ max_len: the maximum length of audio data.
+ data_truncating: the method of truncating data.
+ data_filling: the method of filling data.
+ audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
+ """
+ sample = {}
+
+ # assert audio_data.size(-1) <= max_len, str(audio_data.size())
+
+ # split to three parts
+ chunk_frames = (
+ max_len // audio_cfg["hop_size"] + 1
+ ) # the +1 related to how the spectrogram is computed
+ mel = mel[:chunk_frames]
+
+ audio_data = audio_data[..., :max_len]
+ sample["mel_fusion"] = mel
+ longer = torch.tensor([True])
+
+ sample["longer"] = longer
+ sample["waveform"] = audio_data
+
+ return sample
+
+
+def preprocess(
+ sample,
+ audio_ext,
+ text_ext,
+ max_len,
+ audio_cfg,
+ class_index_dict=None,
+ data_filling="pad",
+ data_truncating="rand_trunc",
+ text_augment_selection=None,
+):
+ """
+ Preprocess a single sample for wdsdataloader.
+ """
+ audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
+ audio_data = int16_to_float32(float32_to_int16(audio_data))
+ audio_data = torch.tensor(audio_data).float()
+
+ # TODO: (yusong) to be include in the future
+ # # if torchaudio not installed, use soundfile to load audio
+ # if torchaudio is None:
+ # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
+ # audio_data = torch.tensor(audio_data).float()
+ # else:
+ # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py
+ # with tempfile.TemporaryDirectory() as dirname:
+ # os.makedirs(dirname, exist_ok=True)
+ # fname = os.path.join(dirname, f"file.flac")
+ # with open(fname, "wb") as stream:
+ # stream.write(sample[audio_ext])
+ # audio_data, orig_sr = torchaudio.load(fname)
+ # audio_data = audio_data[0, :].float()
+
+ sample = get_audio_features(
+ sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
+ )
+ del sample[audio_ext]
+
+ try:
+ json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
+ except:
+ print("sample[__url__]:", sample["__url__"])
+
+ # For selecting augmented text from dataset
+ if text_augment_selection is None or text_augment_selection == "none":
+ texts = json_dict_raw["text"]
+ elif text_augment_selection == "all":
+ if "text_augment_all" in json_dict_raw.keys():
+ texts = json_dict_raw["text_augment_all"]
+ else:
+ texts = json_dict_raw["text"]
+ elif text_augment_selection == "augment_only":
+ if "text_augment_all" in json_dict_raw.keys():
+ if json_dict_raw["text_augment_t5"] is None:
+ texts = json_dict_raw["text"]
+ else:
+ texts = json_dict_raw["text_augment_t5"]
+ else:
+ texts = json_dict_raw["text"]
+ else:
+ raise NotImplementedError(
+ f"text_augment_selection {text_augment_selection} not implemented"
+ )
+ sample["full_text"] = texts
+
+ if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1:
+ texts = random.choice(texts)
+ sample["raw_text"] = texts
+ sample["text"] = tokenizer(texts) # text shape: [num_token]
+ if class_index_dict is not None:
+ # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
+ # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
+ # key, val = class_index_dict
+ # key = key[:].split('\n')
+ # _dict = {k: v for k, v in zip(key, val)}
+ sample["class_label"] = np.zeros(len(class_index_dict.keys()))
+ for x in json_dict_raw["tag"]:
+ sample["class_label"][class_index_dict[x]] = 1
+ sample["class_label"] = torch.tensor(sample["class_label"]).float()
+ del sample[text_ext]
+ sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext
+ sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext
+ sample["audio_orig_sr"] = orig_sr
+ return sample
+
+
+def collate_fn(batch):
+ """
+ Collate function for wdsdataloader.
+ batch: a list of dict, each dict is a sample
+ """
+ # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend.
+ batch_dict = {}
+ for k in batch[0].keys():
+ if isinstance(batch[0][k], dict): # dealwith bert tokenizer output
+ batch_dict[k] = {}
+ for kk in batch[0][k].keys():
+ tmp = []
+ for i in range(len(batch)):
+ tmp.append(batch[i][k][kk])
+ batch_dict[k][kk] = torch.vstack(tmp)
+ elif isinstance(batch[0][k], torch.Tensor):
+ batch_dict[k] = torch.stack([sample[k] for sample in batch])
+ elif isinstance(batch[0][k], np.ndarray):
+ batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch]))
+ else:
+ batch_dict[k] = [sample[k] for sample in batch]
+ return batch_dict
+
+
+# def get_wds_dataset(
+# args,
+# model_cfg,
+# is_train,
+# audio_ext="flac",
+# text_ext="json",
+# max_len=480000,
+# proportion=1.0,
+# sizefilepath_=None,
+# is_local=None,
+# ):
+# """
+# Get a dataset for wdsdataloader.
+# """
+# if is_local is None and (not args.remotedata is None):
+# is_local = not args.remotedata
+
+# input_shards = args.train_data if is_train else args.val_data
+# assert input_shards is not None
+
+# if not sizefilepath_ is None:
+# sizefilepath = sizefilepath_
+# else:
+# sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json")
+
+# if proportion != 1.0:
+# num_samples, num_shards, input_shards, _ = sample_prop(
+# sizefilepath, input_shards, proportion, is_local=is_local
+# )
+# else:
+# num_samples, num_shards = get_dataset_size(
+# input_shards, sizefilepath_=sizefilepath_, is_local=is_local
+# )
+
+# if not num_samples:
+# if is_train:
+# num_samples = args.train_num_samples
+# if not num_samples:
+# raise RuntimeError(
+# "Currently, number of dataset samples must be specified for training dataset. "
+# "Please specify via `--train-num-samples` if no dataset length info present."
+# )
+# else:
+# num_samples = (
+# args.val_num_samples or 0
+# ) # eval will just exhaust the iterator if not specified
+
+# pipeline = [wds.SimpleShardList(input_shards)]
+# # at this point we have an iterator over all the shards
+# # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node
+# if is_train or args.parallel_eval:
+# pipeline.extend(
+# [
+# wds.detshuffle(
+# bufsize=_SHARD_SHUFFLE_SIZE,
+# initial=_SHARD_SHUFFLE_INITIAL,
+# seed=args.seed,
+# ),
+# wds.split_by_node,
+# wds.split_by_worker,
+# # at this point, we have an iterator over the shards assigned to each worker at each node
+# wds.tarfile_to_samples(handler=log_and_continue),
+# wds.shuffle(
+# bufsize=_SAMPLE_SHUFFLE_SIZE,
+# initial=_SAMPLE_SHUFFLE_INITIAL,
+# rng=random.Random(args.seed),
+# ),
+# # wds.repeatedly, # FIXME determine if this is beneficial
+# ]
+# )
+# else:
+# pipeline.extend(
+# [
+# wds.split_by_worker,
+# # at this point, we have an iterator over the shards assigned to each worker
+# wds.tarfile_to_samples(handler=log_and_continue),
+# ]
+# )
+# pipeline.append(
+# wds.map(
+# partial(
+# preprocess,
+# audio_ext=audio_ext,
+# text_ext=text_ext,
+# max_len=max_len,
+# audio_cfg=model_cfg["audio_cfg"],
+# class_index_dict=copy.deepcopy(args.class_index_dict),
+# data_filling=args.data_filling,
+# data_truncating=args.data_truncating,
+# text_augment_selection=args.text_augment_selection,
+# )
+# ),
+# )
+
+# pipeline.append(
+# wds.batched(
+# args.batch_size,
+# partial=not (is_train or args.parallel_eval),
+# collation_fn=collate_fn,
+# )
+# )
+
+# dataset = wds.DataPipeline(*pipeline)
+# if is_train or args.parallel_eval:
+# # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples.
+# # (yusong): See comments below.
+# # roll over and repeat a few samples to get same number of full batches on each node
+# global_batch_size = args.batch_size * args.world_size
+# num_batches = math.ceil(num_samples / global_batch_size)
+# num_workers = max(1, args.workers)
+# num_worker_batches = math.ceil(
+# num_batches / num_workers
+# ) # per dataloader worker
+# num_batches = num_worker_batches * num_workers
+# num_samples = num_batches * global_batch_size
+# dataset = dataset.with_epoch(
+# num_worker_batches
+# ) # each worker is iterating over this
+# else:
+# # last batches are partial, eval is done on single (master) node
+# num_batches = math.ceil(num_samples / args.batch_size)
+
+# kwargs = {}
+# if args.horovod: # multi-node training on summit
+# kwargs["multiprocessing_context"] = "forkserver"
+
+# dataloader = wds.WebLoader(
+# dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs
+# )
+
+# # FIXME not clear which approach is better, with_epoch before vs after dataloader?
+# # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
+# # if is_train:
+# # # roll over and repeat a few samples to get same number of full batches on each node
+# # global_batch_size = args.batch_size * args.world_size
+# # num_batches = math.ceil(num_samples / global_batch_size)
+# # num_workers = max(1, args.workers)
+# # num_batches = math.ceil(num_batches / num_workers) * num_workers
+# # num_samples = num_batches * global_batch_size
+# # dataloader = dataloader.with_epoch(num_batches)
+# # else:
+# # # last batches are partial, eval is done on single (master) node
+# # num_batches = math.ceil(num_samples / args.batch_size)
+
+# # add meta-data to dataloader instance for convenience
+# dataloader.num_batches = num_batches
+# dataloader.num_samples = num_samples
+
+# return DataInfo(dataloader, None)
+
+
+def wds_batch_list2dict(
+ batch,
+ keys=[
+ "__url__",
+ "__key__",
+ "waveform",
+ "text",
+ "raw_text",
+ "audio_name",
+ "text_name",
+ "audio_orig_sr",
+ ],
+):
+ """
+ Return a dictionary of the batch, with keys as the names of the fields.
+ """
+ assert len(keys) == len(
+ batch
+ ), "batch must have same number of keys as keys argument"
+ return {keys[i]: batch[i] for i in range(len(batch))}
+
+
+def get_csv_dataset(args, preprocess_fn, is_train):
+ input_filename = args.train_data if is_train else args.val_data
+ assert input_filename
+ dataset = CsvDataset(
+ input_filename,
+ preprocess_fn,
+ img_key=args.csv_img_key,
+ caption_key=args.csv_caption_key,
+ sep=args.csv_separator,
+ )
+ num_samples = len(dataset)
+ sampler = DistributedSampler(dataset) if args.distributed and is_train else None
+ shuffle = is_train and sampler is None
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ shuffle=shuffle,
+ num_workers=args.workers,
+ pin_memory=True,
+ sampler=sampler,
+ drop_last=is_train,
+ )
+ dataloader.num_samples = num_samples
+ dataloader.num_batches = len(dataloader)
+
+ return DataInfo(dataloader, sampler)
+
+
+def get_toy_dataset(args, model_cfg, is_train):
+ index_path = args.train_data if is_train else args.val_data
+ ipc_path = args.train_ipc if is_train else args.val_ipc
+ assert index_path and ipc_path
+ eval_mode = not is_train
+ dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode)
+
+ num_samples = len(dataset)
+ sampler = (
+ DistributedSampler(dataset, shuffle=False)
+ if args.distributed and is_train
+ else None
+ )
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ shuffle=False,
+ num_workers=args.workers,
+ sampler=sampler,
+ drop_last=is_train,
+ )
+ dataloader.num_samples = num_samples
+ dataloader.num_batches = len(dataloader)
+
+ return DataInfo(dataloader, sampler)
+
+
+def get_dataset_fn(data_path, dataset_type):
+ if dataset_type == "webdataset":
+ return get_wds_dataset
+ elif dataset_type == "csv":
+ return get_csv_dataset
+ elif dataset_type == "auto":
+ ext = data_path.split(".")[-1]
+ if ext in ["csv", "tsv"]:
+ return get_csv_dataset
+ elif ext in ["tar"]:
+ return get_wds_dataset
+ else:
+ raise ValueError(
+ f"Tried to figure out dataset type, but failed for extention {ext}."
+ )
+ elif dataset_type == "toy":
+ return get_toy_dataset
+ else:
+ raise ValueError(f"Unsupported dataset type: {dataset_type}")
+
+
+def get_data(args, model_cfg):
+ data = {}
+
+ args.class_index_dict = load_class_label(args.class_label_path)
+
+ if args.datasetinfos is None:
+ args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
+ if args.dataset_type == "webdataset":
+ args.train_data = get_tar_path_from_dataset_name(
+ args.datasetnames,
+ args.datasetinfos,
+ islocal=not args.remotedata,
+ proportion=args.dataset_proportion,
+ dataset_path=args.datasetpath,
+ full_dataset=args.full_train_dataset,
+ )
+
+ if args.full_train_dataset is None:
+ args.full_train_dataset = []
+ if args.exclude_eval_dataset is None:
+ args.exclude_eval_dataset = []
+ excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset
+
+ val_dataset_names = (
+ [n for n in args.datasetnames if n not in excluded_eval_datasets]
+ if excluded_eval_datasets
+ else args.datasetnames
+ )
+ args.val_dataset_names = val_dataset_names
+ args.val_data = get_tar_path_from_dataset_name(
+ val_dataset_names,
+ ["valid", "test", "eval"],
+ islocal=not args.remotedata,
+ proportion=1,
+ dataset_path=args.datasetpath,
+ full_dataset=None,
+ )
+
+ if args.train_data:
+ data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
+ args, model_cfg, is_train=True
+ )
+
+ if args.val_data:
+ data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
+ args, model_cfg, is_train=False
+ )
+
+ return data
diff --git a/audioldm2/clap/training/params.py b/audioldm2/clap/training/params.py
new file mode 100755
index 0000000000000000000000000000000000000000..0cc1a0e2d982e900988cf5a4b24b2e59b093537b
--- /dev/null
+++ b/audioldm2/clap/training/params.py
@@ -0,0 +1,563 @@
+import argparse
+
+
+def get_default_params(model_name):
+ # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
+ model_name = model_name.lower()
+ if "vit" in model_name:
+ return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
+ else:
+ return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--train-data",
+ type=str,
+ default=None,
+ help="Path to h5 filewith training data",
+ )
+ parser.add_argument(
+ "--val-data",
+ type=str,
+ default=None,
+ help="Path to h5 file with validation data",
+ )
+ parser.add_argument(
+ "--freeze-text",
+ default=False,
+ action="store_true",
+ help="if you need to freeze the text encoder, make this True",
+ )
+ parser.add_argument(
+ "--freeze-text-after",
+ type=int,
+ default=-1,
+ help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it",
+ )
+ parser.add_argument(
+ "--train-ipc",
+ type=str,
+ default=None,
+ help="Path to npy file of the number of instance per class in training data",
+ )
+ parser.add_argument(
+ "--val-ipc",
+ type=str,
+ default=None,
+ help="Path to npy file of the number of instance per class in validation data",
+ )
+ parser.add_argument(
+ "--train-num-samples",
+ type=int,
+ default=None,
+ help="Number of samples in dataset. Required for webdataset if not available in info file.",
+ )
+ parser.add_argument(
+ "--val-num-samples",
+ type=int,
+ default=None,
+ help="Number of samples in dataset. Useful for webdataset if not available in info file.",
+ )
+ parser.add_argument(
+ "--dataset-type",
+ choices=["webdataset", "csv", "auto", "toy"],
+ default="auto",
+ help="Which type of dataset to process.",
+ )
+ parser.add_argument(
+ "--csv-separator",
+ type=str,
+ default="\t",
+ help="For csv-like datasets, which separator to use.",
+ )
+ parser.add_argument(
+ "--csv-img-key",
+ type=str,
+ default="filepath",
+ help="For csv-like datasets, the name of the key for the image paths.",
+ )
+ parser.add_argument(
+ "--csv-caption-key",
+ type=str,
+ default="title",
+ help="For csv-like datasets, the name of the key for the captions.",
+ )
+ parser.add_argument(
+ "--imagenet-val",
+ type=str,
+ default=None,
+ help="Path to imagenet val set for conducting zero shot evaluation.",
+ )
+ parser.add_argument(
+ "--imagenet-v2",
+ type=str,
+ default=None,
+ help="Path to imagenet v2 for conducting zero shot evaluation.",
+ )
+ parser.add_argument(
+ "--datasetnames",
+ nargs="+",
+ default=None,
+ help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects",
+ )
+ parser.add_argument(
+ "--full-train-dataset",
+ nargs="+",
+ default=None,
+ help="Which dataset will be trained with all the subsets. (train+test)",
+ )
+ parser.add_argument(
+ "--exclude-eval-dataset",
+ nargs="+",
+ default=None,
+ help="Which dataset will be excluded with evaluation",
+ )
+ parser.add_argument(
+ "--datasetinfos",
+ nargs="+",
+ default=None,
+ help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval",
+ )
+ parser.add_argument(
+ "--dataset-proportion",
+ type=float,
+ default=1.0,
+ help="How much proportion of dataset we want to train.",
+ )
+ parser.add_argument(
+ "--remotedata",
+ default=False,
+ action="store_true",
+ help="if the dataset is remote, set this flag",
+ )
+ parser.add_argument(
+ "--class-label-path",
+ type=str,
+ default=None,
+ help="The path of the class label pickle or csv.",
+ )
+ parser.add_argument(
+ "--datasetpath",
+ type=str,
+ default="/mnt/audio_clip/webdataset_tar",
+ help="The path to the dataset",
+ )
+ parser.add_argument(
+ "--logs",
+ type=str,
+ default="./logs/",
+ help="Where to store tensorboard logs. Use None to avoid storing logs.",
+ )
+ parser.add_argument(
+ "--log-local",
+ action="store_true",
+ default=False,
+ help="log files on local master, otherwise global master only.",
+ )
+ parser.add_argument(
+ "--name",
+ type=str,
+ default=None,
+ help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
+ )
+ parser.add_argument(
+ "--workers", type=int, default=1, help="Number of workers per GPU."
+ )
+ parser.add_argument(
+ "--batch-size", type=int, default=64, help="Batch size per GPU."
+ )
+ parser.add_argument(
+ "--epochs", type=int, default=32, help="Number of epochs to train for."
+ )
+ parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
+ parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
+ parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
+ parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
+ parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.")
+ parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
+
+ parser.add_argument(
+ "--split-opt",
+ action="store_true",
+ default=False,
+ help="Use this flag to skip the learning rate decay.",
+ )
+ parser.add_argument(
+ "--lr-pretrained", type=float, default=None, help="Learning rate for text."
+ )
+ parser.add_argument(
+ "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text."
+ )
+ parser.add_argument(
+ "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text."
+ )
+ parser.add_argument(
+ "--eps-pretrained", type=float, default=None, help="Adam epsilon for text."
+ )
+ parser.add_argument(
+ "--wd-pretrained", type=float, default=0.2, help="Weight decay for text."
+ )
+ parser.add_argument(
+ "--momentum-pretrained", type=float, default=0.9, help="Momentum for text."
+ )
+ parser.add_argument(
+ "--lr-new", type=float, default=None, help="Learning rate for audio."
+ )
+ parser.add_argument(
+ "--beta1-new", type=float, default=None, help="Adam beta 1 for audio."
+ )
+ parser.add_argument(
+ "--beta2-new", type=float, default=None, help="Adam beta 2 for audio."
+ )
+ parser.add_argument(
+ "--eps-new", type=float, default=None, help="Adam epsilon for audio."
+ )
+ parser.add_argument(
+ "--wd-new", type=float, default=0.2, help="Weight decay for audio."
+ )
+ parser.add_argument(
+ "--momentum-new", type=float, default=0.9, help="Momentum for audio."
+ )
+ parser.add_argument(
+ "--warmup", type=int, default=10000, help="Number of steps to warmup for."
+ )
+ parser.add_argument(
+ "--use-bn-sync",
+ default=False,
+ action="store_true",
+ help="Whether to use batch norm sync.",
+ )
+ parser.add_argument(
+ "--skip-scheduler",
+ action="store_true",
+ default=False,
+ help="Use this flag to skip the learning rate decay.",
+ )
+ parser.add_argument(
+ "--save-frequency", type=int, default=1, help="How often to save checkpoints."
+ )
+ parser.add_argument(
+ "--save-top-performance",
+ type=int,
+ default=0,
+ help="Save the top x performance weights if the value >0",
+ )
+ parser.add_argument(
+ "--save-most-recent",
+ action="store_true",
+ default=False,
+ help="Always save the most recent model trained to epoch_latest.pt.",
+ )
+ parser.add_argument(
+ "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
+ )
+ parser.add_argument(
+ "--val-frequency",
+ type=int,
+ default=1,
+ help="How often to run evaluation with val data.",
+ )
+ parser.add_argument(
+ "--resume",
+ default=None,
+ type=str,
+ help="path to latest checkpoint (default: none)",
+ )
+ parser.add_argument(
+ "--precision",
+ choices=["amp", "fp16", "fp32"],
+ default="amp",
+ help="Floating point precision.",
+ )
+ parser.add_argument(
+ "--amodel",
+ type=str,
+ default="RN50",
+ help="Name of the audio backbone to use.",
+ )
+ parser.add_argument(
+ "--tmodel",
+ type=str,
+ default="transformer",
+ help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]",
+ )
+ parser.add_argument(
+ "--pretrained-audio",
+ default="",
+ type=str,
+ help="Use a pretrained audio model weights for the audio encoder of CLAP",
+ )
+ parser.add_argument(
+ "--pretrained-text",
+ default="",
+ type=str,
+ help="Use a pretrained text model weights for the text encoder of CLAP",
+ )
+ parser.add_argument(
+ "--pretrained",
+ default="",
+ type=str,
+ help="Use a pretrained CLIP model weights with the specified tag or file path.",
+ )
+ parser.add_argument(
+ "--pretrained-image",
+ default=False,
+ action="store_true",
+ help="Load imagenet pretrained weights for image tower backbone if available.",
+ )
+ parser.add_argument(
+ "--lock-image",
+ default=False,
+ action="store_true",
+ help="Lock full image tower by disabling gradients.",
+ )
+ parser.add_argument(
+ "--lock-image-unlocked-groups",
+ type=int,
+ default=0,
+ help="Leave last n image tower layer groups unlocked.",
+ )
+ parser.add_argument(
+ "--lock-image-freeze-bn-stats",
+ default=False,
+ action="store_true",
+ help="Freeze BatchNorm running stats in image tower for any locked layers.",
+ )
+ parser.add_argument(
+ "--local-loss",
+ default=False,
+ action="store_true",
+ help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)",
+ )
+ parser.add_argument(
+ "--gather-with-grad",
+ default=False,
+ action="store_true",
+ help="enable full distributed gradient for feature gather",
+ )
+ parser.add_argument(
+ "--force-quick-gelu",
+ default=False,
+ action="store_true",
+ help="Force use of QuickGELU activation for non-OpenAI transformer models.",
+ )
+ parser.add_argument(
+ "--torchscript",
+ default=False,
+ action="store_true",
+ help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
+ )
+ parser.add_argument(
+ "--trace",
+ default=False,
+ action="store_true",
+ help="torch.jit.trace the model for inference / eval only",
+ )
+ # arguments for distributed training
+ parser.add_argument(
+ "--dist-url",
+ default="env://",
+ type=str,
+ help="url used to set up distributed training",
+ )
+ parser.add_argument(
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
+ )
+ parser.add_argument(
+ "--report-to",
+ default="",
+ type=str,
+ help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']",
+ )
+ parser.add_argument(
+ "--wandb-notes", default="", type=str, help="Notes if logging with wandb"
+ )
+ parser.add_argument(
+ "--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
+ )
+ parser.add_argument(
+ "--debug",
+ default=False,
+ action="store_true",
+ help="If true, more information is logged.",
+ )
+ parser.add_argument(
+ "--copy-codebase",
+ default=False,
+ action="store_true",
+ help="If true, we copy the entire base on the log diretory, and execute from there.",
+ )
+ parser.add_argument(
+ "--horovod",
+ default=False,
+ action="store_true",
+ help="Use horovod for distributed training.",
+ )
+ parser.add_argument(
+ "--ddp-static-graph",
+ default=False,
+ action="store_true",
+ help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
+ )
+ parser.add_argument(
+ "--no-set-device-rank",
+ default=False,
+ action="store_true",
+ help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
+ )
+ parser.add_argument("--seed", type=int, default=4242, help="Default random seed.")
+
+ parser.add_argument(
+ "--top-k-checkpoint-select-dataset",
+ type=str,
+ default="all",
+ help="The dataset of selecting top-k checkpoint.",
+ )
+
+ # @R10, @R@5, @R1, mAP@10
+ parser.add_argument(
+ "--top-k-checkpoint-select-metric",
+ type=str,
+ default="_R@10",
+ help="The metric for selecting top-k checkpoint.",
+ )
+ parser.add_argument(
+ "--openai-model-cache-dir",
+ type=str,
+ default="~/.cache/clip",
+ help="Directory to download OpenAI models.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="adamw",
+ help="can be AdamW or SGD",
+ )
+ parser.add_argument(
+ "--parallel-eval",
+ default=False,
+ action="store_true",
+ help="Eval in parallel (multi-GPU, multi-node).",
+ )
+
+ parser.add_argument(
+ "--no-eval",
+ default=False,
+ action="store_true",
+ help="Training without evaluation.",
+ )
+
+ parser.add_argument(
+ "--lp-mlp",
+ default=False,
+ action="store_true",
+ help="Linear Probe using MLP layer or not.",
+ )
+
+ parser.add_argument(
+ "--lp-freeze",
+ default=False,
+ action="store_true",
+ help="Linear Probe using Freeze CLAP or not",
+ )
+
+ parser.add_argument(
+ "--lp-act",
+ default="None",
+ type=str,
+ help="Options are ['relu','elu','prelu','softmax','sigmoid']",
+ )
+
+ parser.add_argument(
+ "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe."
+ )
+
+ parser.add_argument(
+ "--lp-metrics",
+ type=str,
+ default="map,mauc,acc",
+ help="Metrics of Linear Probe.",
+ )
+
+ parser.add_argument(
+ "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe"
+ )
+ parser.add_argument(
+ "--kappa",
+ type=float,
+ default=0,
+ help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss",
+ )
+
+ parser.add_argument(
+ "--data-filling",
+ type=str,
+ default="pad",
+ help="type of data filling when the audio length is shorter than the max length."
+ "Can be one of the following: repeat, repeatpad, pad",
+ )
+ parser.add_argument(
+ "--data-truncating",
+ type=str,
+ default="rand_trunc",
+ help="type of data truncation when the audio length is longer than the max length."
+ "Can be one of the following: rand_trunc, fusion",
+ )
+
+ parser.add_argument(
+ "--clap-mlploss",
+ default=False,
+ action="store_true",
+ help="Using MLP loss for CLAP model or not",
+ )
+
+ parser.add_argument(
+ "--wandb-id",
+ type=str,
+ default=None,
+ help="the id of wandb experiment to restore.",
+ )
+
+ parser.add_argument(
+ "--sleep", type=float, default=0, help="sleep n seconds before start training"
+ )
+
+ # variable length processing
+ parser.add_argument(
+ "--enable-fusion",
+ default=False,
+ action="store_true",
+ help="Enable feature funsion for variable-length data",
+ )
+
+ parser.add_argument(
+ "--fusion-type",
+ type=str,
+ default="None",
+ help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']",
+ )
+
+ parser.add_argument(
+ "--mixup",
+ default=False,
+ action="store_true",
+ help="Enable mixup in finetuning training.",
+ )
+ parser.add_argument(
+ "--text-augment-selection",
+ type=str,
+ default=None,
+ help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']",
+ )
+
+ args = parser.parse_args()
+
+ # If some params are not passed, we use the default values based on model name.
+ default_params = get_default_params(args.amodel)
+ for name, val in default_params.items():
+ if getattr(args, name) is None:
+ setattr(args, name, val)
+
+ return args
diff --git a/audioldm2/hifigan/LICENSE b/audioldm2/hifigan/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..5afae394d6b37da0e12ba6b290d2512687f421ac
--- /dev/null
+++ b/audioldm2/hifigan/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Jungil Kong
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/audioldm2/hifigan/__init__.py b/audioldm2/hifigan/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..34e055557bf2ecb457376663b67390543c71fb1f
--- /dev/null
+++ b/audioldm2/hifigan/__init__.py
@@ -0,0 +1,8 @@
+from .models_v2 import Generator
+from .models import Generator as Generator_old
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
diff --git a/audioldm2/hifigan/models.py b/audioldm2/hifigan/models.py
new file mode 100755
index 0000000000000000000000000000000000000000..c4382cc39de0463f9b7c0f33f037dbc233e7cb36
--- /dev/null
+++ b/audioldm2/hifigan/models.py
@@ -0,0 +1,174 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+LRELU_SLOPE = 0.1
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+class ResBlock(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock, self).__init__()
+ self.h = h
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class Generator(torch.nn.Module):
+ def __init__(self, h):
+ super(Generator, self).__init__()
+ self.h = h
+ self.num_kernels = len(h.resblock_kernel_sizes)
+ self.num_upsamples = len(h.upsample_rates)
+ self.conv_pre = weight_norm(
+ Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
+ )
+ resblock = ResBlock
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ h.upsample_initial_channel // (2**i),
+ h.upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
+ for j, (k, d) in enumerate(
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+ ):
+ self.resblocks.append(resblock(h, ch, k, d))
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x):
+ x = self.conv_pre(x)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x = self.ups[i](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ # print("Removing weight norm...")
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
diff --git a/audioldm2/hifigan/models_v2.py b/audioldm2/hifigan/models_v2.py
new file mode 100755
index 0000000000000000000000000000000000000000..27a2df6b54bdd3a5b259645442624800ac0e8afe
--- /dev/null
+++ b/audioldm2/hifigan/models_v2.py
@@ -0,0 +1,395 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+LRELU_SLOPE = 0.1
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock1, self).__init__()
+ self.h = h
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+ super(ResBlock2, self).__init__()
+ self.h = h
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ ]
+ )
+ self.convs.apply(init_weights)
+
+ def forward(self, x):
+ for c in self.convs:
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+class Generator(torch.nn.Module):
+ def __init__(self, h):
+ super(Generator, self).__init__()
+ self.h = h
+ self.num_kernels = len(h.resblock_kernel_sizes)
+ self.num_upsamples = len(h.upsample_rates)
+ self.conv_pre = weight_norm(
+ Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3)
+ )
+ resblock = ResBlock1 if h.resblock == "1" else ResBlock2
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ h.upsample_initial_channel // (2**i),
+ h.upsample_initial_channel // (2 ** (i + 1)),
+ u * 2,
+ u,
+ padding=u // 2 + u % 2,
+ output_padding=u % 2,
+ )
+ )
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
+ for j, (k, d) in enumerate(
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+ ):
+ self.resblocks.append(resblock(h, ch, k, d))
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x):
+ # import ipdb; ipdb.set_trace()
+ x = self.conv_pre(x)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x = self.ups[i](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ # print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+
+
+##################################################################################################
+
+# import torch
+# import torch.nn as nn
+# import torch.nn.functional as F
+# from torch.nn import Conv1d, ConvTranspose1d
+# from torch.nn.utils import weight_norm, remove_weight_norm
+
+# LRELU_SLOPE = 0.1
+
+
+# def init_weights(m, mean=0.0, std=0.01):
+# classname = m.__class__.__name__
+# if classname.find("Conv") != -1:
+# m.weight.data.normal_(mean, std)
+
+
+# def get_padding(kernel_size, dilation=1):
+# return int((kernel_size * dilation - dilation) / 2)
+
+
+# class ResBlock(torch.nn.Module):
+# def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+# super(ResBlock, self).__init__()
+# self.h = h
+# self.convs1 = nn.ModuleList(
+# [
+# weight_norm(
+# Conv1d(
+# channels,
+# channels,
+# kernel_size,
+# 1,
+# dilation=dilation[0],
+# padding=get_padding(kernel_size, dilation[0]),
+# )
+# ),
+# weight_norm(
+# Conv1d(
+# channels,
+# channels,
+# kernel_size,
+# 1,
+# dilation=dilation[1],
+# padding=get_padding(kernel_size, dilation[1]),
+# )
+# ),
+# weight_norm(
+# Conv1d(
+# channels,
+# channels,
+# kernel_size,
+# 1,
+# dilation=dilation[2],
+# padding=get_padding(kernel_size, dilation[2]),
+# )
+# ),
+# ]
+# )
+# self.convs1.apply(init_weights)
+
+# self.convs2 = nn.ModuleList(
+# [
+# weight_norm(
+# Conv1d(
+# channels,
+# channels,
+# kernel_size,
+# 1,
+# dilation=1,
+# padding=get_padding(kernel_size, 1),
+# )
+# ),
+# weight_norm(
+# Conv1d(
+# channels,
+# channels,
+# kernel_size,
+# 1,
+# dilation=1,
+# padding=get_padding(kernel_size, 1),
+# )
+# ),
+# weight_norm(
+# Conv1d(
+# channels,
+# channels,
+# kernel_size,
+# 1,
+# dilation=1,
+# padding=get_padding(kernel_size, 1),
+# )
+# ),
+# ]
+# )
+# self.convs2.apply(init_weights)
+
+# def forward(self, x):
+# for c1, c2 in zip(self.convs1, self.convs2):
+# xt = F.leaky_relu(x, LRELU_SLOPE)
+# xt = c1(xt)
+# xt = F.leaky_relu(xt, LRELU_SLOPE)
+# xt = c2(xt)
+# x = xt + x
+# return x
+
+# def remove_weight_norm(self):
+# for l in self.convs1:
+# remove_weight_norm(l)
+# for l in self.convs2:
+# remove_weight_norm(l)
+
+# class Generator(torch.nn.Module):
+# def __init__(self, h):
+# super(Generator, self).__init__()
+# self.h = h
+# self.num_kernels = len(h.resblock_kernel_sizes)
+# self.num_upsamples = len(h.upsample_rates)
+# self.conv_pre = weight_norm(
+# Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
+# )
+# resblock = ResBlock
+
+# self.ups = nn.ModuleList()
+# for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+# self.ups.append(
+# weight_norm(
+# ConvTranspose1d(
+# h.upsample_initial_channel // (2**i),
+# h.upsample_initial_channel // (2 ** (i + 1)),
+# k,
+# u,
+# padding=(k - u) // 2,
+# )
+# )
+# )
+
+# self.resblocks = nn.ModuleList()
+# for i in range(len(self.ups)):
+# ch = h.upsample_initial_channel // (2 ** (i + 1))
+# for j, (k, d) in enumerate(
+# zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+# ):
+# self.resblocks.append(resblock(h, ch, k, d))
+
+# self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+# self.ups.apply(init_weights)
+# self.conv_post.apply(init_weights)
+
+# def forward(self, x):
+# x = self.conv_pre(x)
+# for i in range(self.num_upsamples):
+# x = F.leaky_relu(x, LRELU_SLOPE)
+# x = self.ups[i](x)
+# xs = None
+# for j in range(self.num_kernels):
+# if xs is None:
+# xs = self.resblocks[i * self.num_kernels + j](x)
+# else:
+# xs += self.resblocks[i * self.num_kernels + j](x)
+# x = xs / self.num_kernels
+# x = F.leaky_relu(x)
+# x = self.conv_post(x)
+# x = torch.tanh(x)
+
+# return x
+
+# def remove_weight_norm(self):
+# print("Removing weight norm...")
+# for l in self.ups:
+# remove_weight_norm(l)
+# for l in self.resblocks:
+# l.remove_weight_norm()
+# remove_weight_norm(self.conv_pre)
+# remove_weight_norm(self.conv_post)
diff --git a/audioldm2/latent_diffusion/__init__.py b/audioldm2/latent_diffusion/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/models/__init__.py b/audioldm2/latent_diffusion/models/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/models/ddim.py b/audioldm2/latent_diffusion/models/ddim.py
new file mode 100755
index 0000000000000000000000000000000000000000..0c07207af7959847552805f00831122304b4330e
--- /dev/null
+++ b/audioldm2/latent_diffusion/models/ddim.py
@@ -0,0 +1,487 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from audioldm2.latent_diffusion.modules.diffusionmodules.util import (
+ make_ddim_sampling_parameters,
+ make_ddim_timesteps,
+ noise_like,
+ extract_into_tensor,
+)
+
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+ self.device = device
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != self.device:
+ attr = attr.to(self.device)
+ setattr(self, name, attr)
+
+ def make_schedule(
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+ ):
+ self.ddim_timesteps = make_ddim_timesteps(
+ ddim_discr_method=ddim_discretize,
+ num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
+ verbose=verbose,
+ )
+ alphas_cumprod = self.model.alphas_cumprod
+ assert (
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
+ ), "alphas have to be defined for each timestep"
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer("betas", to_torch(self.model.betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer(
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
+ )
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer(
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod",
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod",
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
+ )
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+ alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,
+ verbose=verbose,
+ )
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
+ self.register_buffer("ddim_alphas", ddim_alphas)
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev)
+ / (1 - self.alphas_cumprod)
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+ )
+ self.register_buffer(
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
+ )
+
+ @torch.no_grad()
+ def sample(
+ self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.0,
+ mask=None,
+ x0=None,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ ucg_schedule=None,
+ **kwargs,
+ ):
+ # if conditioning is not None:
+ # if isinstance(conditioning, dict):
+ # ctmp = conditioning[list(conditioning.keys())[0]]
+ # while isinstance(ctmp, list): ctmp = ctmp[0]
+ # cbs = ctmp.shape[0]
+ # if cbs != batch_size:
+ # print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ # elif isinstance(conditioning, list):
+ # for ctmp in conditioning:
+ # if ctmp.shape[0] != batch_size:
+ # print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ # else:
+ # if conditioning.shape[0] != batch_size:
+ # print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(
+ conditioning,
+ size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask,
+ x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ ucg_schedule=ucg_schedule,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(
+ self,
+ cond,
+ shape,
+ x_T=None,
+ ddim_use_original_steps=False,
+ callback=None,
+ timesteps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ img_callback=None,
+ log_every_t=100,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ dynamic_threshold=None,
+ ucg_schedule=None,
+ ):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = (
+ self.ddpm_num_timesteps
+ if ddim_use_original_steps
+ else self.ddim_timesteps
+ )
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = (
+ int(
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
+ * self.ddim_timesteps.shape[0]
+ )
+ - 1
+ )
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
+ time_range = (
+ reversed(range(0, timesteps))
+ if ddim_use_original_steps
+ else np.flip(timesteps)
+ )
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(
+ x0, ts
+ ) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1.0 - mask) * img
+
+ if ucg_schedule is not None:
+ assert len(ucg_schedule) == len(time_range)
+ unconditional_guidance_scale = ucg_schedule[i]
+
+ outs = self.p_sample_ddim(
+ img,
+ cond,
+ ts,
+ index=index,
+ use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised,
+ temperature=temperature,
+ noise_dropout=noise_dropout,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ )
+ img, pred_x0 = outs
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates["x_inter"].append(img)
+ intermediates["pred_x0"].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(
+ self,
+ x,
+ c,
+ t,
+ index,
+ repeat_noise=False,
+ use_original_steps=False,
+ quantize_denoised=False,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ dynamic_threshold=None,
+ ):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
+ model_output = self.model.apply_model(x, t, c)
+ else:
+ x_in = x
+ t_in = t
+
+ assert isinstance(c, dict)
+ assert isinstance(unconditional_conditioning, dict)
+
+ model_uncond = self.model.apply_model(
+ x_in, t_in, unconditional_conditioning
+ )
+ model_t = self.model.apply_model(x_in, t_in, c)
+
+ model_output = model_uncond + unconditional_guidance_scale * (
+ model_t - model_uncond
+ )
+
+ if self.model.parameterization == "v":
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+ else:
+ e_t = model_output
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps", "not implemented"
+ e_t = score_corrector.modify_score(
+ self.model, e_t, x, t, c, **corrector_kwargs
+ )
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = (
+ self.model.alphas_cumprod_prev
+ if use_original_steps
+ else self.ddim_alphas_prev
+ )
+ sqrt_one_minus_alphas = (
+ self.model.sqrt_one_minus_alphas_cumprod
+ if use_original_steps
+ else self.ddim_sqrt_one_minus_alphas
+ )
+ sigmas = (
+ self.model.ddim_sigmas_for_original_num_steps
+ if use_original_steps
+ else self.ddim_sigmas
+ )
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full(
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+ )
+
+ # current prediction for x_0
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+ if dynamic_threshold is not None:
+ raise NotImplementedError()
+
+ # direction pointing to x_t
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.0:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def encode(
+ self,
+ x0,
+ c,
+ t_enc,
+ use_original_steps=False,
+ return_intermediates=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ callback=None,
+ ):
+ num_reference_steps = (
+ self.ddpm_num_timesteps
+ if use_original_steps
+ else self.ddim_timesteps.shape[0]
+ )
+
+ assert t_enc <= num_reference_steps
+ num_steps = t_enc
+
+ if use_original_steps:
+ alphas_next = self.alphas_cumprod[:num_steps]
+ alphas = self.alphas_cumprod_prev[:num_steps]
+ else:
+ alphas_next = self.ddim_alphas[:num_steps]
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+ x_next = x0
+ intermediates = []
+ inter_steps = []
+ for i in tqdm(range(num_steps), desc="Encoding Image"):
+ t = torch.full(
+ (x0.shape[0],), i, device=self.model.device, dtype=torch.long
+ )
+ if unconditional_guidance_scale == 1.0:
+ noise_pred = self.model.apply_model(x_next, t, c)
+ else:
+ assert unconditional_conditioning is not None
+ e_t_uncond, noise_pred = torch.chunk(
+ self.model.apply_model(
+ torch.cat((x_next, x_next)),
+ torch.cat((t, t)),
+ torch.cat((unconditional_conditioning, c)),
+ ),
+ 2,
+ )
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (
+ noise_pred - e_t_uncond
+ )
+
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+ weighted_noise_pred = (
+ alphas_next[i].sqrt()
+ * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
+ * noise_pred
+ )
+ x_next = xt_weighted + weighted_noise_pred
+ if (
+ return_intermediates
+ and i % (num_steps // return_intermediates) == 0
+ and i < num_steps - 1
+ ):
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ elif return_intermediates and i >= num_steps - 2:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ if callback:
+ callback(i)
+
+ out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
+ if return_intermediates:
+ out.update({"intermediates": intermediates})
+ return x_next, out
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (
+ extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
+ )
+
+ @torch.no_grad()
+ def decode(
+ self,
+ x_latent,
+ cond,
+ t_start,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ use_original_steps=False,
+ callback=None,
+ ):
+ timesteps = (
+ np.arange(self.ddpm_num_timesteps)
+ if use_original_steps
+ else self.ddim_timesteps
+ )
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full(
+ (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
+ )
+ x_dec, _ = self.p_sample_ddim(
+ x_dec,
+ cond,
+ ts,
+ index=index,
+ use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ if callback:
+ callback(i)
+ return x_dec
diff --git a/audioldm2/latent_diffusion/models/ddpm.py b/audioldm2/latent_diffusion/models/ddpm.py
new file mode 100755
index 0000000000000000000000000000000000000000..df3a6c032ba2ec61250212a31d68184e763dcf0e
--- /dev/null
+++ b/audioldm2/latent_diffusion/models/ddpm.py
@@ -0,0 +1,1840 @@
+from multiprocessing.sharedctypes import Value
+import os
+
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange, repeat
+from contextlib import contextmanager
+from functools import partial
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from audioldm2.latent_diffusion.modules.encoders.modules import *
+
+from audioldm2.latent_diffusion.util import (
+ exists,
+ default,
+ count_params,
+ instantiate_from_config,
+)
+from audioldm2.latent_diffusion.modules.ema import LitEma
+from audioldm2.latent_diffusion.modules.distributions.distributions import (
+ DiagonalGaussianDistribution,
+)
+
+# from latent_encoder.autoencoder import (
+# VQModelInterface,
+# IdentityFirstStage,
+# AutoencoderKL,
+# )
+
+from audioldm2.latent_diffusion.modules.diffusionmodules.util import (
+ make_beta_schedule,
+ extract_into_tensor,
+ noise_like,
+)
+
+from audioldm2.latent_diffusion.models.ddim import DDIMSampler
+from audioldm2.latent_diffusion.models.plms import PLMSSampler
+import soundfile as sf
+import os
+
+__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
+
+CACHE_DIR = os.getenv(
+ "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2")
+)
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(nn.Module):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(
+ self,
+ unet_config,
+ sampling_rate=None,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ latent_t_size=256,
+ latent_f_size=16,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.0,
+ v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.0,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.0,
+ evaluator=None,
+ device=None,
+ ):
+ super().__init__()
+ assert parameterization in [
+ "eps",
+ "x0",
+ "v",
+ ], 'currently only supporting "eps" and "x0" and "v"'
+ self.parameterization = parameterization
+ self.state = None
+ self.device = device
+ # print(
+ # f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
+ # )
+ assert sampling_rate is not None
+ self.validation_folder_name = "temp_name"
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.sampling_rate = sampling_rate
+
+ self.clap = CLAPAudioEmbeddingClassifierFreev2(
+ pretrained_path="",
+ sampling_rate=self.sampling_rate,
+ embed_mode="audio",
+ amodel="HTSAT-base",
+ )
+
+ self.initialize_param_check_toolkit()
+
+ self.latent_t_size = latent_t_size
+ self.latent_f_size = latent_f_size
+
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(
+ ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet
+ )
+
+ self.register_schedule(
+ given_betas=given_betas,
+ beta_schedule=beta_schedule,
+ timesteps=timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s,
+ )
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+ else:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=False)
+
+ self.logger_save_dir = None
+ self.logger_exp_name = None
+ self.logger_exp_group_name = None
+ self.logger_version = None
+
+ self.label_indices_total = None
+ # To avoid the system cannot find metric value for checkpoint
+ self.metrics_buffer = {
+ "val/kullback_leibler_divergence_sigmoid": 15.0,
+ "val/kullback_leibler_divergence_softmax": 10.0,
+ "val/psnr": 0.0,
+ "val/ssim": 0.0,
+ "val/inception_score_mean": 1.0,
+ "val/inception_score_std": 0.0,
+ "val/kernel_inception_distance_mean": 0.0,
+ "val/kernel_inception_distance_std": 0.0,
+ "val/frechet_inception_distance": 133.0,
+ "val/frechet_audio_distance": 32.0,
+ }
+ self.initial_learning_rate = None
+ self.test_data_subset_path = None
+
+ def get_log_dir(self):
+ return os.path.join(
+ self.logger_save_dir, self.logger_exp_group_name, self.logger_exp_name
+ )
+
+ def set_log_dir(self, save_dir, exp_group_name, exp_name):
+ self.logger_save_dir = save_dir
+ self.logger_exp_group_name = exp_group_name
+ self.logger_exp_name = exp_name
+
+ def register_schedule(
+ self,
+ given_betas=None,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ ):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(
+ beta_schedule,
+ timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s,
+ )
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+ (timesteps,) = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert (
+ alphas_cumprod.shape[0] == self.num_timesteps
+ ), "alphas have to be defined for each timestep"
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer("betas", to_torch(betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+ )
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (
+ 1.0 - alphas_cumprod_prev
+ ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer("posterior_variance", to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer(
+ "posterior_log_variance_clipped",
+ to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
+ )
+ self.register_buffer(
+ "posterior_mean_coef1",
+ to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
+ )
+ self.register_buffer(
+ "posterior_mean_coef2",
+ to_torch(
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
+ ),
+ )
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas**2 / (
+ 2
+ * self.posterior_variance
+ * to_torch(alphas)
+ * (1 - self.alphas_cumprod)
+ )
+ elif self.parameterization == "x0":
+ lvlb_weights = (
+ 0.5
+ * np.sqrt(torch.Tensor(alphas_cumprod))
+ / (2.0 * 1 - torch.Tensor(alphas_cumprod))
+ )
+ elif self.parameterization == "v":
+ lvlb_weights = torch.ones_like(
+ self.betas**2
+ / (
+ 2
+ * self.posterior_variance
+ * to_torch(alphas)
+ * (1 - self.alphas_cumprod)
+ )
+ )
+ else:
+ raise NotImplementedError("mu not supported")
+ # TODO how to choose this term
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ # if context is not None:
+ # print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ # if context is not None:
+ # print(f"{context}: Restored training weights")
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = (
+ self.load_state_dict(sd, strict=False)
+ if not only_model
+ else self.model.load_state_dict(sd, strict=False)
+ )
+ print(
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+ )
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
+ )
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+ * noise
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ if clip_denoised:
+ x_recon.clamp_(-1.0, 1.0)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+ x_start=x_recon, x_t=x, t=t
+ )
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(
+ x=x, t=t, clip_denoised=clip_denoised
+ )
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (
+ (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
+ )
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(
+ reversed(range(0, self.num_timesteps)),
+ desc="Sampling t",
+ total=self.num_timesteps,
+ ):
+ img = self.p_sample(
+ img,
+ torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised,
+ )
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+ shape = (batch_size, channels, self.latent_t_size, self.latent_f_size)
+ self.channels
+ return self.p_sample_loop(shape, return_intermediates=return_intermediates)
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == "l1":
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == "l2":
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
+ else:
+ raise NotImplementedError("unknown loss type '{loss_type}'")
+
+ return loss
+
+ def predict_start_from_z_and_v(self, x_t, t, v):
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ def predict_eps_from_z_and_v(self, x_t, t, v):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
+ * x_t
+ )
+
+ def get_v(self, x, noise, t):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+ )
+
+ def forward(self, x, *args, **kwargs):
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+ t = torch.randint(
+ 0, self.num_timesteps, (x.shape[0],), device=self.device
+ ).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ # fbank, log_magnitudes_stft, label_indices, fname, waveform, clip_label, text = batch
+ # fbank, stft, label_indices, fname, waveform, text = batch
+ fname, text, waveform, stft, fbank = (
+ batch["fname"],
+ batch["text"],
+ batch["waveform"],
+ batch["stft"],
+ batch["log_mel_spec"],
+ )
+ # for i in range(fbank.size(0)):
+ # fb = fbank[i].numpy()
+ # seg_lb = seg_label[i].numpy()
+ # logits = np.mean(seg_lb, axis=0)
+ # index = np.argsort(logits)[::-1][:5]
+ # plt.imshow(seg_lb[:,index], aspect="auto")
+ # plt.title(index)
+ # plt.savefig("%s_label.png" % i)
+ # plt.close()
+ # plt.imshow(fb, aspect="auto")
+ # plt.savefig("%s_fb.png" % i)
+ # plt.close()
+ ret = {}
+
+ ret["fbank"] = (
+ fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float()
+ )
+ ret["stft"] = stft.to(memory_format=torch.contiguous_format).float()
+ # ret["clip_label"] = clip_label.to(memory
+ # _format=torch.contiguous_format).float()
+ ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float()
+ ret["text"] = list(text)
+ ret["fname"] = fname
+
+ for key in batch.keys():
+ if key not in ret.keys():
+ ret[key] = batch[key]
+
+ return ret[k]
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
+ denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(
+ batch_size=N, return_intermediates=True
+ )
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+ def initialize_param_check_toolkit(self):
+ self.tracked_steps = 0
+ self.param_dict = {}
+
+ def statistic_require_grad_tensor_number(self, module, name=None):
+ requires_grad_num = 0
+ total_num = 0
+ require_grad_tensor = None
+ for p in module.parameters():
+ if p.requires_grad:
+ requires_grad_num += 1
+ if require_grad_tensor is None:
+ require_grad_tensor = p
+ total_num += 1
+ print(
+ "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)"
+ % (name, requires_grad_num, total_num, requires_grad_num / total_num)
+ )
+ return require_grad_tensor
+
+
+class LatentDiffusion(DDPM):
+ """main class"""
+
+ def __init__(
+ self,
+ first_stage_config,
+ cond_stage_config=None,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ optimize_ddpm_parameter=True,
+ unconditional_prob_cfg=0.1,
+ warmup_steps=10000,
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ batchsize=None,
+ evaluation_params={},
+ scale_by_std=False,
+ base_learning_rate=None,
+ *args,
+ **kwargs,
+ ):
+ self.learning_rate = base_learning_rate
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ self.warmup_steps = warmup_steps
+
+ if optimize_ddpm_parameter:
+ if unconditional_prob_cfg == 0.0:
+ "You choose to optimize DDPM. The classifier free guidance scale should be 0.1"
+ unconditional_prob_cfg = 0.1
+ else:
+ if unconditional_prob_cfg == 0.1:
+ "You choose not to optimize DDPM. The classifier free guidance scale should be 0.0"
+ unconditional_prob_cfg = 0.0
+
+ self.evaluation_params = evaluation_params
+ assert self.num_timesteps_cond <= kwargs["timesteps"]
+
+ # for backwards compatibility after implementation of DiffusionWrapper
+ # if conditioning_key is None:
+ # conditioning_key = "concat" if concat_mode else "crossattn"
+ # if cond_stage_config == "__is_unconditional__":
+ # conditioning_key = None
+
+ conditioning_key = list(cond_stage_config.keys())
+
+ self.conditioning_key = conditioning_key
+
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+
+ self.optimize_ddpm_parameter = optimize_ddpm_parameter
+ # if(not optimize_ddpm_parameter):
+ # print("Warning: Close the optimization of the latent diffusion model")
+ # for p in self.model.parameters():
+ # p.requires_grad=False
+
+ self.concat_mode = concat_mode
+ self.cond_stage_key = cond_stage_key
+ self.cond_stage_key_orig = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+ except:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer("scale_factor", torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.unconditional_prob_cfg = unconditional_prob_cfg
+ self.cond_stage_models = nn.ModuleList([])
+ self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+ self.conditional_dry_run_finished = False
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+
+ for each in self.cond_stage_models:
+ params = params + list(
+ each.parameters()
+ ) # Add the parameter from the conditional stage
+
+ if self.learn_logvar:
+ print("Diffusion model optimizing logvar")
+ params.append(self.logvar)
+ opt = torch.optim.AdamW(params, lr=lr)
+ # if self.use_scheduler:
+ # assert "target" in self.scheduler_config
+ # scheduler = instantiate_from_config(self.scheduler_config)
+
+ # print("Setting up LambdaLR scheduler...")
+ # scheduler = [
+ # {
+ # "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
+ # "interval": "step",
+ # "frequency": 1,
+ # }
+ # ]
+ # return [opt], scheduler
+ return opt
+
+ def make_cond_schedule(
+ self,
+ ):
+ self.cond_ids = torch.full(
+ size=(self.num_timesteps,),
+ fill_value=self.num_timesteps - 1,
+ dtype=torch.long,
+ )
+ ids = torch.round(
+ torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
+ ).long()
+ self.cond_ids[: self.num_timesteps_cond] = ids
+
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx):
+ # only for very first batch
+ if (
+ self.scale_factor == 1
+ and self.scale_by_std
+ and self.current_epoch == 0
+ and self.global_step == 0
+ and batch_idx == 0
+ and not self.restarted_from_ckpt
+ ):
+ # assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer("scale_factor", 1.0 / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+ def register_schedule(
+ self,
+ given_betas=None,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ ):
+ super().register_schedule(
+ given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
+ )
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def make_decision(self, probability):
+ if float(torch.rand(1)) < probability:
+ return True
+ else:
+ return False
+
+ def instantiate_cond_stage(self, config):
+ self.cond_stage_model_metadata = {}
+ for i, cond_model_key in enumerate(config.keys()):
+ model = instantiate_from_config(config[cond_model_key])
+ self.cond_stage_models.append(model)
+ self.cond_stage_model_metadata[cond_model_key] = {
+ "model_idx": i,
+ "cond_stage_key": config[cond_model_key]["cond_stage_key"],
+ "conditioning_key": config[cond_model_key]["conditioning_key"],
+ }
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(
+ f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
+ )
+ return self.scale_factor * z
+
+ def get_learned_conditioning(self, c, key, unconditional_cfg):
+ assert key in self.cond_stage_model_metadata.keys()
+
+ # Classifier-free guidance
+ if not unconditional_cfg:
+ c = self.cond_stage_models[
+ self.cond_stage_model_metadata[key]["model_idx"]
+ ](c)
+ else:
+ # when the cond_stage_key is "all", pick one random element out
+ if isinstance(c, dict):
+ c = c[list(c.keys())[0]]
+
+ if isinstance(c, torch.Tensor):
+ batchsize = c.size(0)
+ elif isinstance(c, list):
+ batchsize = len(c)
+ else:
+ raise NotImplementedError()
+
+ c = self.cond_stage_models[
+ self.cond_stage_model_metadata[key]["model_idx"]
+ ].get_unconditional_condition(batchsize)
+
+ return c
+
+ def get_input(
+ self,
+ batch,
+ k,
+ return_first_stage_encode=True,
+ return_decoding_output=False,
+ return_encoder_input=False,
+ return_encoder_output=False,
+ unconditional_prob_cfg=0.1,
+ ):
+ x = super().get_input(batch, k)
+
+ x = x.to(self.device)
+
+ if return_first_stage_encode:
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ else:
+ z = None
+ cond_dict = {}
+ if len(self.cond_stage_model_metadata.keys()) > 0:
+ unconditional_cfg = False
+ if self.conditional_dry_run_finished and self.make_decision(
+ unconditional_prob_cfg
+ ):
+ unconditional_cfg = True
+ for cond_model_key in self.cond_stage_model_metadata.keys():
+ cond_stage_key = self.cond_stage_model_metadata[cond_model_key][
+ "cond_stage_key"
+ ]
+
+ if cond_model_key in cond_dict.keys():
+ continue
+
+ if not self.training:
+ if isinstance(
+ self.cond_stage_models[
+ self.cond_stage_model_metadata[cond_model_key]["model_idx"]
+ ],
+ CLAPAudioEmbeddingClassifierFreev2,
+ ):
+ print(
+ "Warning: CLAP model normally should use text for evaluation"
+ )
+
+ # The original data for conditioning
+ # If cond_model_key is "all", that means the conditional model need all the information from a batch
+
+ if cond_stage_key != "all":
+ xc = super().get_input(batch, cond_stage_key)
+ if type(xc) == torch.Tensor:
+ xc = xc.to(self.device)
+ else:
+ xc = batch
+
+ # if cond_stage_key is "all", xc will be a dictionary containing all keys
+ # Otherwise xc will be an entry of the dictionary
+ c = self.get_learned_conditioning(
+ xc, key=cond_model_key, unconditional_cfg=unconditional_cfg
+ )
+
+ # cond_dict will be used to condition the diffusion model
+ # If one conditional model return multiple conditioning signal
+ if isinstance(c, dict):
+ for k in c.keys():
+ cond_dict[k] = c[k]
+ else:
+ cond_dict[cond_model_key] = c
+
+ # If the key is accidently added to the dictionary and not in the condition list, remove the condition
+ # for k in list(cond_dict.keys()):
+ # if(k not in self.cond_stage_model_metadata.keys()):
+ # del cond_dict[k]
+
+ out = [z, cond_dict]
+
+ if return_decoding_output:
+ xrec = self.decode_first_stage(z)
+ out += [xrec]
+
+ if return_encoder_input:
+ out += [x]
+
+ if return_encoder_output:
+ out += [encoder_posterior]
+
+ if not self.conditional_dry_run_finished:
+ self.conditional_dry_run_finished = True
+
+ # Output is a dictionary, where the value could only be tensor or tuple
+ return out
+
+ def decode_first_stage(self, z):
+ with torch.no_grad():
+ z = 1.0 / self.scale_factor * z
+ decoding = self.first_stage_model.decode(z)
+ return decoding
+
+ def mel_spectrogram_to_waveform(
+ self, mel, savepath=".", bs=None, name="outwav", save=True
+ ):
+ # Mel: [bs, 1, t-steps, fbins]
+ if len(mel.size()) == 4:
+ mel = mel.squeeze(1)
+ mel = mel.permute(0, 2, 1)
+ waveform = self.first_stage_model.vocoder(mel)
+ waveform = waveform.cpu().detach().numpy()
+ if save:
+ self.save_waveform(waveform, savepath, name)
+ return waveform
+
+ def encode_first_stage(self, x):
+ with torch.no_grad():
+ return self.first_stage_model.encode(x)
+
+ def extract_possible_loss_in_cond_dict(self, cond_dict):
+ # This function enable the conditional module to return loss function that can optimize them
+
+ assert isinstance(cond_dict, dict)
+ losses = {}
+
+ for cond_key in cond_dict.keys():
+ if "loss" in cond_key and "noncond" in cond_key:
+ assert cond_key not in losses.keys()
+ losses[cond_key] = cond_dict[cond_key]
+
+ return losses
+
+ def filter_useful_cond_dict(self, cond_dict):
+ new_cond_dict = {}
+ for key in cond_dict.keys():
+ if key in self.cond_stage_model_metadata.keys():
+ new_cond_dict[key] = cond_dict[key]
+
+ # All the conditional key in the metadata should be used
+ for key in self.cond_stage_model_metadata.keys():
+ assert key in new_cond_dict.keys(), "%s, %s" % (
+ key,
+ str(new_cond_dict.keys()),
+ )
+
+ return new_cond_dict
+
+ def shared_step(self, batch, **kwargs):
+ if self.training:
+ # Classifier-free guidance
+ unconditional_prob_cfg = self.unconditional_prob_cfg
+ else:
+ unconditional_prob_cfg = 0.0 # TODO possible bug here
+
+ x, c = self.get_input(
+ batch, self.first_stage_key, unconditional_prob_cfg=unconditional_prob_cfg
+ )
+
+ if self.optimize_ddpm_parameter:
+ loss, loss_dict = self(x, self.filter_useful_cond_dict(c))
+ else:
+ loss_dict = {}
+ loss = None
+
+ additional_loss_for_cond_modules = self.extract_possible_loss_in_cond_dict(c)
+ assert isinstance(additional_loss_for_cond_modules, dict)
+
+ loss_dict.update(additional_loss_for_cond_modules)
+
+ if len(additional_loss_for_cond_modules.keys()) > 0:
+ for k in additional_loss_for_cond_modules.keys():
+ if loss is None:
+ loss = additional_loss_for_cond_modules[k]
+ else:
+ loss = loss + additional_loss_for_cond_modules[k]
+
+ # for k,v in additional_loss_for_cond_modules.items():
+ # self.log(
+ # "cond_stage/"+k,
+ # float(v),
+ # prog_bar=True,
+ # logger=True,
+ # on_step=True,
+ # on_epoch=True,
+ # )
+ if self.training:
+ assert loss is not None
+
+ return loss, loss_dict
+
+ def forward(self, x, c, *args, **kwargs):
+ t = torch.randint(
+ 0, self.num_timesteps, (x.shape[0],), device=self.device
+ ).long()
+
+ # assert c is not None
+ # c = self.get_learned_conditioning(c)
+
+ loss, loss_dict = self.p_losses(x, c, t, *args, **kwargs)
+ return loss, loss_dict
+
+ def reorder_cond_dict(self, cond_dict):
+ # To make sure the order is correct
+ new_cond_dict = {}
+ for key in self.conditioning_key:
+ new_cond_dict[key] = cond_dict[key]
+ return new_cond_dict
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+ cond = self.reorder_cond_dict(cond)
+
+ x_recon = self.model(x_noisy, t, cond_dict=cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def p_losses(self, x_start, cond, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_output = self.apply_model(x_noisy, t, cond)
+
+ loss_dict = {}
+ prefix = "train" if self.training else "val"
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError()
+ # print(model_output.size(), target.size())
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+ loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()})
+
+ logvar_t = self.logvar[t].to(self.device)
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+ if self.learn_logvar:
+ loss_dict.update({f"{prefix}/loss_gamma": loss.mean()})
+ loss_dict.update({"logvar": self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f"{prefix}/loss_vlb": loss_vlb})
+ loss += self.original_elbo_weight * loss_vlb
+ loss_dict.update({f"{prefix}/loss": loss})
+
+ return loss, loss_dict
+
+ def p_mean_variance(
+ self,
+ x,
+ c,
+ t,
+ clip_denoised: bool,
+ return_codebook_ids=False,
+ quantize_denoised=False,
+ return_x0=False,
+ score_corrector=None,
+ corrector_kwargs=None,
+ ):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(
+ self, model_out, x, t, c, **corrector_kwargs
+ )
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1.0, 1.0)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+ x_start=x_recon, x_t=x, t=t
+ )
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(
+ self,
+ x,
+ c,
+ t,
+ clip_denoised=False,
+ repeat_noise=False,
+ return_codebook_ids=False,
+ quantize_denoised=False,
+ return_x0=False,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ ):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(
+ x=x,
+ c=c,
+ t=t,
+ clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ )
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.0:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (
+ (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
+ )
+
+ # if return_codebook_ids:
+ # return model_mean + nonzero_mask * (
+ # 0.5 * model_log_variance
+ # ).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return (
+ model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
+ x0,
+ )
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(
+ self,
+ cond,
+ shape,
+ verbose=True,
+ callback=None,
+ quantize_denoised=False,
+ img_callback=None,
+ mask=None,
+ x0=None,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ batch_size=None,
+ x_T=None,
+ start_T=None,
+ log_every_t=None,
+ ):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+ b = batch_size if batch_size is not None else shape[0]
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {
+ key: cond[key][:batch_size]
+ if not isinstance(cond[key], list)
+ else list(map(lambda x: x[:batch_size], cond[key]))
+ for key in cond
+ }
+ else:
+ cond = (
+ [c[:batch_size] for c in cond]
+ if isinstance(cond, list)
+ else cond[:batch_size]
+ )
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = (
+ tqdm(
+ reversed(range(0, timesteps)),
+ desc="Progressive Generation",
+ total=timesteps,
+ )
+ if verbose
+ else reversed(range(0, timesteps))
+ )
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != "hybrid"
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(
+ img,
+ cond,
+ ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised,
+ return_x0=True,
+ temperature=temperature[i],
+ noise_dropout=noise_dropout,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ )
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1.0 - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(
+ self,
+ cond,
+ shape,
+ return_intermediates=False,
+ x_T=None,
+ verbose=True,
+ callback=None,
+ timesteps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ img_callback=None,
+ start_T=None,
+ log_every_t=None,
+ ):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = (
+ tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
+ if verbose
+ else reversed(range(0, timesteps))
+ )
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != "hybrid"
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img = self.p_sample(
+ img,
+ cond,
+ ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised,
+ )
+
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1.0 - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(
+ self,
+ cond,
+ batch_size=16,
+ return_intermediates=False,
+ x_T=None,
+ verbose=True,
+ timesteps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ shape=None,
+ **kwargs,
+ ):
+ if shape is None:
+ shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {
+ key: cond[key][:batch_size]
+ if not isinstance(cond[key], list)
+ else list(map(lambda x: x[:batch_size], cond[key]))
+ for key in cond
+ }
+ else:
+ cond = (
+ [c[:batch_size] for c in cond]
+ if isinstance(cond, list)
+ else cond[:batch_size]
+ )
+ return self.p_sample_loop(
+ cond,
+ shape,
+ return_intermediates=return_intermediates,
+ x_T=x_T,
+ verbose=verbose,
+ timesteps=timesteps,
+ quantize_denoised=quantize_denoised,
+ mask=mask,
+ x0=x0,
+ **kwargs,
+ )
+
+ def save_waveform(self, waveform, savepath, name="outwav"):
+ for i in range(waveform.shape[0]):
+ if type(name) is str:
+ path = os.path.join(
+ savepath, "%s_%s_%s.wav" % (self.global_step, i, name)
+ )
+ elif type(name) is list:
+ path = os.path.join(
+ savepath,
+ "%s.wav"
+ % (
+ os.path.basename(name[i])
+ if (not ".wav" in name[i])
+ else os.path.basename(name[i]).split(".")[0]
+ ),
+ )
+ else:
+ raise NotImplementedError
+ todo_waveform = waveform[i, 0]
+ todo_waveform = (
+ todo_waveform / np.max(np.abs(todo_waveform))
+ ) * 0.8 # Normalize the energy of the generation output
+ sf.write(path, todo_waveform, samplerate=self.sampling_rate)
+
+ @torch.no_grad()
+ def sample_log(
+ self,
+ cond,
+ batch_size,
+ ddim,
+ ddim_steps,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ use_plms=False,
+ mask=None,
+ **kwargs,
+ ):
+ if mask is not None:
+ shape = (self.channels, mask.size()[-2], mask.size()[-1])
+ else:
+ shape = (self.channels, self.latent_t_size, self.latent_f_size)
+
+ intermediate = None
+ if ddim and not use_plms:
+ ddim_sampler = DDIMSampler(self)
+ samples, intermediates = ddim_sampler.sample(
+ ddim_steps,
+ batch_size,
+ shape,
+ cond,
+ verbose=False,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ mask=mask,
+ **kwargs,
+ )
+ elif use_plms:
+ plms_sampler = PLMSSampler(self)
+ samples, intermediates = plms_sampler.sample(
+ ddim_steps,
+ batch_size,
+ shape,
+ cond,
+ verbose=False,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ mask=mask,
+ unconditional_conditioning=unconditional_conditioning,
+ **kwargs,
+ )
+
+ else:
+ samples, intermediates = self.sample(
+ cond=cond,
+ batch_size=batch_size,
+ return_intermediates=True,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ mask=mask,
+ unconditional_conditioning=unconditional_conditioning,
+ **kwargs,
+ )
+
+ return samples, intermediate
+
+ @torch.no_grad()
+ def generate_batch(
+ self,
+ batch,
+ ddim_steps=200,
+ ddim_eta=1.0,
+ x_T=None,
+ n_gen=1,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ use_plms=False,
+ **kwargs,
+ ):
+ # Generate n_gen times and select the best
+ # Batch: audio, text, fnames
+ assert x_T is None
+
+ if use_plms:
+ assert ddim_steps is not None
+
+ use_ddim = ddim_steps is not None
+
+ # with self.ema_scope("Plotting"):
+ for i in range(1):
+ z, c = self.get_input(
+ batch,
+ self.first_stage_key,
+ unconditional_prob_cfg=0.0, # Do not output unconditional information in the c
+ )
+
+ c = self.filter_useful_cond_dict(c)
+
+ text = super().get_input(batch, "text")
+
+ # Generate multiple samples
+ batch_size = z.shape[0] * n_gen
+
+ # Generate multiple samples at a time and filter out the best
+ # The condition to the diffusion wrapper can have many format
+ for cond_key in c.keys():
+ if isinstance(c[cond_key], list):
+ for i in range(len(c[cond_key])):
+ c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0)
+ elif isinstance(c[cond_key], dict):
+ for k in c[cond_key].keys():
+ c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0)
+ else:
+ c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0)
+
+ text = text * n_gen
+
+ if unconditional_guidance_scale != 1.0:
+ unconditional_conditioning = {}
+ for key in self.cond_stage_model_metadata:
+ model_idx = self.cond_stage_model_metadata[key]["model_idx"]
+ unconditional_conditioning[key] = self.cond_stage_models[
+ model_idx
+ ].get_unconditional_condition(batch_size)
+
+ fnames = list(super().get_input(batch, "fname"))
+ samples, _ = self.sample_log(
+ cond=c,
+ batch_size=batch_size,
+ x_T=x_T,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ use_plms=use_plms,
+ )
+
+ mel = self.decode_first_stage(samples)
+
+ waveform = self.mel_spectrogram_to_waveform(
+ mel, savepath="", bs=None, name=fnames, save=False
+ )
+
+ if n_gen > 1:
+ best_index = []
+ similarity = self.clap.cos_similarity(
+ torch.FloatTensor(waveform).squeeze(1), text
+ )
+ for i in range(z.shape[0]):
+ candidates = similarity[i :: z.shape[0]]
+ max_index = torch.argmax(candidates).item()
+ best_index.append(i + max_index * z.shape[0])
+
+ waveform = waveform[best_index]
+
+ print("Similarity between generated audio and text:")
+ print(' '.join('{:.2f}'.format(num) for num in similarity.detach().cpu().tolist()))
+ print("Choose the following indexes as the output:", best_index)
+
+ return waveform
+
+ @torch.no_grad()
+ def generate_sample(
+ self,
+ batchs,
+ ddim_steps=200,
+ ddim_eta=1.0,
+ x_T=None,
+ n_gen=1,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ name=None,
+ use_plms=False,
+ limit_num=None,
+ **kwargs,
+ ):
+ # Generate n_gen times and select the best
+ # Batch: audio, text, fnames
+ assert x_T is None
+ try:
+ batchs = iter(batchs)
+ except TypeError:
+ raise ValueError("The first input argument should be an iterable object")
+
+ if use_plms:
+ assert ddim_steps is not None
+
+ use_ddim = ddim_steps is not None
+ if name is None:
+ name = self.get_validation_folder_name()
+
+ waveform_save_path = os.path.join(self.get_log_dir(), name)
+ os.makedirs(waveform_save_path, exist_ok=True)
+ print("Waveform save path: ", waveform_save_path)
+
+ if (
+ "audiocaps" in waveform_save_path
+ and len(os.listdir(waveform_save_path)) >= 964
+ ):
+ print("The evaluation has already been done at %s" % waveform_save_path)
+ return waveform_save_path
+
+ with self.ema_scope("Plotting"):
+ for i, batch in enumerate(batchs):
+ z, c = self.get_input(
+ batch,
+ self.first_stage_key,
+ unconditional_prob_cfg=0.0, # Do not output unconditional information in the c
+ )
+
+ if limit_num is not None and i * z.size(0) > limit_num:
+ break
+
+ c = self.filter_useful_cond_dict(c)
+
+ text = super().get_input(batch, "text")
+
+ # Generate multiple samples
+ batch_size = z.shape[0] * n_gen
+
+ # Generate multiple samples at a time and filter out the best
+ # The condition to the diffusion wrapper can have many format
+ for cond_key in c.keys():
+ if isinstance(c[cond_key], list):
+ for i in range(len(c[cond_key])):
+ c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0)
+ elif isinstance(c[cond_key], dict):
+ for k in c[cond_key].keys():
+ c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0)
+ else:
+ c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0)
+
+ text = text * n_gen
+
+ if unconditional_guidance_scale != 1.0:
+ unconditional_conditioning = {}
+ for key in self.cond_stage_model_metadata:
+ model_idx = self.cond_stage_model_metadata[key]["model_idx"]
+ unconditional_conditioning[key] = self.cond_stage_models[
+ model_idx
+ ].get_unconditional_condition(batch_size)
+
+ fnames = list(super().get_input(batch, "fname"))
+ samples, _ = self.sample_log(
+ cond=c,
+ batch_size=batch_size,
+ x_T=x_T,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ use_plms=use_plms,
+ )
+
+ mel = self.decode_first_stage(samples)
+
+ waveform = self.mel_spectrogram_to_waveform(
+ mel, savepath=waveform_save_path, bs=None, name=fnames, save=False
+ )
+
+ if n_gen > 1:
+ try:
+ best_index = []
+ similarity = self.clap.cos_similarity(
+ torch.FloatTensor(waveform).squeeze(1), text
+ )
+ for i in range(z.shape[0]):
+ candidates = similarity[i :: z.shape[0]]
+ max_index = torch.argmax(candidates).item()
+ best_index.append(i + max_index * z.shape[0])
+
+ waveform = waveform[best_index]
+
+ print("Similarity between generated audio and text", similarity)
+ print("Choose the following indexes:", best_index)
+ except Exception as e:
+ print("Warning: while calculating CLAP score (not fatal), ", e)
+ self.save_waveform(waveform, waveform_save_path, name=fnames)
+ return waveform_save_path
+
+
+class DiffusionWrapper(nn.Module):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+
+ self.conditioning_key = conditioning_key
+
+ for key in self.conditioning_key:
+ if (
+ "concat" in key
+ or "crossattn" in key
+ or "hybrid" in key
+ or "film" in key
+ or "noncond" in key
+ ):
+ continue
+ else:
+ raise Value("The conditioning key %s is illegal" % key)
+
+ self.being_verbosed_once = False
+
+ def forward(self, x, t, cond_dict: dict = {}):
+ x = x.contiguous()
+ t = t.contiguous()
+
+ # x with condition (or maybe not)
+ xc = x
+
+ y = None
+ context_list, attn_mask_list = [], []
+
+ conditional_keys = cond_dict.keys()
+
+ for key in conditional_keys:
+ if "concat" in key:
+ xc = torch.cat([x, cond_dict[key].unsqueeze(1)], dim=1)
+ elif "film" in key:
+ if y is None:
+ y = cond_dict[key].squeeze(1)
+ else:
+ y = torch.cat([y, cond_dict[key].squeeze(1)], dim=-1)
+ elif "crossattn" in key:
+ # assert context is None, "You can only have one context matrix, got %s" % (cond_dict.keys())
+ if isinstance(cond_dict[key], dict):
+ for k in cond_dict[key].keys():
+ if "crossattn" in k:
+ context, attn_mask = cond_dict[key][
+ k
+ ] # crossattn_audiomae_pooled: torch.Size([12, 128, 768])
+ else:
+ assert len(cond_dict[key]) == 2, (
+ "The context condition for %s you returned should have two element, one context one mask"
+ % (key)
+ )
+ context, attn_mask = cond_dict[key]
+
+ # The input to the UNet model is a list of context matrix
+ context_list.append(context)
+ attn_mask_list.append(attn_mask)
+
+ elif (
+ "noncond" in key
+ ): # If you use loss function in the conditional module, include the keyword "noncond" in the return dictionary
+ continue
+ else:
+ raise NotImplementedError()
+
+ # if(not self.being_verbosed_once):
+ # print("The input shape to the diffusion model is as follows:")
+ # print("xc", xc.size())
+ # print("t", t.size())
+ # for i in range(len(context_list)):
+ # print("context_%s" % i, context_list[i].size(), attn_mask_list[i].size())
+ # if(y is not None):
+ # print("y", y.size())
+ # self.being_verbosed_once = True
+ out = self.diffusion_model(
+ xc, t, context_list=context_list, y=y, context_attn_mask_list=attn_mask_list
+ )
+ return out
+ self.warmup_step()
+
+ if (
+ self.state is None
+ and len(self.trainer.optimizers[0].state_dict()["state"].keys()) > 0
+ ):
+ self.state = (
+ self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"].clone()
+ )
+ elif self.state is not None and batch_idx % 1000 == 0:
+ assert (
+ torch.sum(
+ torch.abs(
+ self.state
+ - self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"]
+ )
+ )
+ > 1e-7
+ ), "Optimizer is not working"
+
+ if len(self.metrics_buffer.keys()) > 0:
+ for k in self.metrics_buffer.keys():
+ self.log(
+ k,
+ self.metrics_buffer[k],
+ prog_bar=False,
+ logger=True,
+ on_step=True,
+ on_epoch=False,
+ )
+ print(k, self.metrics_buffer[k])
+ self.metrics_buffer = {}
+
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(
+ {k: float(v) for k, v in loss_dict.items()},
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+
+ self.log(
+ "global_step",
+ float(self.global_step),
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=False,
+ )
+
+ lr = self.trainer.optimizers[0].param_groups[0]["lr"]
+ self.log(
+ "lr_abs",
+ float(lr),
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=False,
+ )
+
+
+if __name__ == "__main__":
+ import yaml
+
+ model_config = "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/stable-diffusion/models/ldm/text2img256/config.yaml"
+ model_config = yaml.load(open(model_config, "r"), Loader=yaml.FullLoader)
+
+ latent_diffusion = LatentDiffusion(**model_config["model"]["params"])
+
+ import ipdb
+
+ ipdb.set_trace()
diff --git a/audioldm2/latent_diffusion/models/plms.py b/audioldm2/latent_diffusion/models/plms.py
new file mode 100755
index 0000000000000000000000000000000000000000..9c80796442bd653ac3dc1970c12f621068a4d821
--- /dev/null
+++ b/audioldm2/latent_diffusion/models/plms.py
@@ -0,0 +1,360 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from audioldm2.latent_diffusion.modules.diffusionmodules.util import (
+ make_ddim_sampling_parameters,
+ make_ddim_timesteps,
+ noise_like,
+)
+
+
+class PLMSSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+ ):
+ if ddim_eta != 0:
+ ddim_eta = 0
+ # raise ValueError('ddim_eta must be 0 for PLMS')
+
+ self.ddim_timesteps = make_ddim_timesteps(
+ ddim_discr_method=ddim_discretize,
+ num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
+ verbose=verbose,
+ )
+ alphas_cumprod = self.model.alphas_cumprod
+ assert (
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
+ ), "alphas have to be defined for each timestep"
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer("betas", to_torch(self.model.betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer(
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
+ )
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer(
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod",
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod",
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
+ )
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+ alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,
+ verbose=verbose,
+ )
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
+ self.register_buffer("ddim_alphas", ddim_alphas)
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev)
+ / (1 - self.alphas_cumprod)
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+ )
+ self.register_buffer(
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
+ )
+
+ @torch.no_grad()
+ def sample(
+ self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.0,
+ mask=None,
+ x0=None,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs,
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+ )
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+ )
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f"Data shape for PLMS sampling is {size}")
+
+ samples, intermediates = self.plms_sampling(
+ conditioning,
+ size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask,
+ x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def plms_sampling(
+ self,
+ cond,
+ shape,
+ x_T=None,
+ ddim_use_original_steps=False,
+ callback=None,
+ timesteps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ img_callback=None,
+ log_every_t=100,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ ):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = (
+ self.ddpm_num_timesteps
+ if ddim_use_original_steps
+ else self.ddim_timesteps
+ )
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = (
+ int(
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
+ * self.ddim_timesteps.shape[0]
+ )
+ - 1
+ )
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
+ time_range = (
+ list(reversed(range(0, timesteps)))
+ if ddim_use_original_steps
+ else np.flip(timesteps)
+ )
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
+ old_eps = []
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ ts_next = torch.full(
+ (b,),
+ time_range[min(i + 1, len(time_range) - 1)],
+ device=device,
+ dtype=torch.long,
+ )
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(
+ x0, ts
+ ) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1.0 - mask) * img
+
+ outs = self.p_sample_plms(
+ img,
+ cond,
+ ts,
+ index=index,
+ use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised,
+ temperature=temperature,
+ noise_dropout=noise_dropout,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ old_eps=old_eps,
+ t_next=ts_next,
+ )
+ img, pred_x0, e_t = outs
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates["x_inter"].append(img)
+ intermediates["pred_x0"].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_plms(
+ self,
+ x,
+ c,
+ t,
+ index,
+ repeat_noise=False,
+ use_original_steps=False,
+ quantize_denoised=False,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ old_eps=None,
+ t_next=None,
+ ):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if (
+ unconditional_conditioning is None
+ or unconditional_guidance_scale == 1.0
+ ):
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(
+ self.model, e_t, x, t, c, **corrector_kwargs
+ )
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = (
+ self.model.alphas_cumprod_prev
+ if use_original_steps
+ else self.ddim_alphas_prev
+ )
+ sqrt_one_minus_alphas = (
+ self.model.sqrt_one_minus_alphas_cumprod
+ if use_original_steps
+ else self.ddim_sqrt_one_minus_alphas
+ )
+ sigmas = (
+ self.model.ddim_sigmas_for_original_num_steps
+ if use_original_steps
+ else self.ddim_sigmas
+ )
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full(
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+ )
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.0:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (
+ 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
+ ) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
diff --git a/audioldm2/latent_diffusion/modules/__init__.py b/audioldm2/latent_diffusion/modules/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/modules/attention.py b/audioldm2/latent_diffusion/modules/attention.py
new file mode 100755
index 0000000000000000000000000000000000000000..6116342da98249c681ddb5f696b48dc0f5ac69f2
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/attention.py
@@ -0,0 +1,467 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+from audioldm2.latent_diffusion.modules.diffusionmodules.util import checkpoint
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+ )
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
+ )
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b (h w) c")
+ k = rearrange(k, "b c h w -> b c (h w)")
+ w_ = torch.einsum("bij,bjk->bik", q, k)
+
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, "b c h w -> b c (h w)")
+ w_ = rearrange(w_, "b i j -> b j i")
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+# class CrossAttention(nn.Module):
+# """
+# ### Cross Attention Layer
+# This falls-back to self-attention when conditional embeddings are not specified.
+# """
+
+# use_flash_attention: bool = True
+
+# # use_flash_attention: bool = False
+# def __init__(
+# self,
+# query_dim,
+# context_dim=None,
+# heads=8,
+# dim_head=64,
+# dropout=0.0,
+# is_inplace: bool = True,
+# ):
+# # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
+# """
+# :param d_model: is the input embedding size
+# :param n_heads: is the number of attention heads
+# :param d_head: is the size of a attention head
+# :param d_cond: is the size of the conditional embeddings
+# :param is_inplace: specifies whether to perform the attention softmax computation inplace to
+# save memory
+# """
+# super().__init__()
+
+# self.is_inplace = is_inplace
+# self.n_heads = heads
+# self.d_head = dim_head
+
+# # Attention scaling factor
+# self.scale = dim_head**-0.5
+
+# # The normal self-attention layer
+# if context_dim is None:
+# context_dim = query_dim
+
+# # Query, key and value mappings
+# d_attn = dim_head * heads
+# self.to_q = nn.Linear(query_dim, d_attn, bias=False)
+# self.to_k = nn.Linear(context_dim, d_attn, bias=False)
+# self.to_v = nn.Linear(context_dim, d_attn, bias=False)
+
+# # Final linear layer
+# self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
+
+# # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
+# # Flash attention is only used if it's installed
+# # and `CrossAttention.use_flash_attention` is set to `True`.
+# try:
+# # You can install flash attention by cloning their Github repo,
+# # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
+# # and then running `python setup.py install`
+# from flash_attn.flash_attention import FlashAttention
+
+# self.flash = FlashAttention()
+# # Set the scale for scaled dot-product attention.
+# self.flash.softmax_scale = self.scale
+# # Set to `None` if it's not installed
+# except ImportError:
+# self.flash = None
+
+# def forward(self, x, context=None, mask=None):
+# """
+# :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
+# :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
+# """
+
+# # If `cond` is `None` we perform self attention
+# has_cond = context is not None
+# if not has_cond:
+# context = x
+
+# # Get query, key and value vectors
+# q = self.to_q(x)
+# k = self.to_k(context)
+# v = self.to_v(context)
+
+# # Use flash attention if it's available and the head size is less than or equal to `128`
+# if (
+# CrossAttention.use_flash_attention
+# and self.flash is not None
+# and not has_cond
+# and self.d_head <= 128
+# ):
+# return self.flash_attention(q, k, v)
+# # Otherwise, fallback to normal attention
+# else:
+# return self.normal_attention(q, k, v)
+
+# def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+# """
+# #### Flash Attention
+# :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+# :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+# :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+# """
+
+# # Get batch size and number of elements along sequence axis (`width * height`)
+# batch_size, seq_len, _ = q.shape
+
+# # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
+# # shape `[batch_size, seq_len, 3, n_heads * d_head]`
+# qkv = torch.stack((q, k, v), dim=2)
+# # Split the heads
+# qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
+
+# # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
+# # fit this size.
+# if self.d_head <= 32:
+# pad = 32 - self.d_head
+# elif self.d_head <= 64:
+# pad = 64 - self.d_head
+# elif self.d_head <= 128:
+# pad = 128 - self.d_head
+# else:
+# raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")
+
+# # Pad the heads
+# if pad:
+# qkv = torch.cat(
+# (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
+# )
+
+# # Compute attention
+# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+# # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
+# # TODO here I add the dtype changing
+# out, _ = self.flash(qkv.type(torch.float16))
+# # Truncate the extra head size
+# out = out[:, :, :, : self.d_head].float()
+# # Reshape to `[batch_size, seq_len, n_heads * d_head]`
+# out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
+
+# # Map to `[batch_size, height * width, d_model]` with a linear layer
+# return self.to_out(out)
+
+# def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+# """
+# #### Normal Attention
+
+# :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+# :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+# :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+# """
+
+# # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
+# q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32]
+# k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32]
+# v = v.view(*v.shape[:2], self.n_heads, -1)
+
+# # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
+# attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
+
+# # Compute softmax
+# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
+# if self.is_inplace:
+# half = attn.shape[0] // 2
+# attn[half:] = attn[half:].softmax(dim=-1)
+# attn[:half] = attn[:half].softmax(dim=-1)
+# else:
+# attn = attn.softmax(dim=-1)
+
+# # Compute attention output
+# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+# # attn: [bs, 20, 64, 1]
+# # v: [bs, 1, 20, 32]
+# out = torch.einsum("bhij,bjhd->bihd", attn, v)
+# # Reshape to `[batch_size, height * width, n_heads * d_head]`
+# out = out.reshape(*out.shape[:2], -1)
+# # Map to `[batch_size, height * width, d_model]` with a linear layer
+# return self.to_out(out)
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, "b ... -> b (...)")
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
+ sim.masked_fill_(~(mask == 1), max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum("b i j, b j d -> b i d", attn, v)
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ ):
+ super().__init__()
+ self.attn1 = CrossAttention(
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None, mask=None):
+ if context is None:
+ return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
+ else:
+ return checkpoint(
+ self._forward, (x, context, mask), self.parameters(), self.checkpoint
+ )
+
+ def _forward(self, x, context=None, mask=None):
+ x = self.attn1(self.norm1(x)) + x
+ x = self.attn2(self.norm2(x), context=context, mask=mask) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None,
+ ):
+ super().__init__()
+
+ context_dim = context_dim
+
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+
+ self.proj_in = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
+ )
+ for d in range(depth)
+ ]
+ )
+
+ self.proj_out = zero_module(
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c")
+ for block in self.transformer_blocks:
+ x = block(x, context=context, mask=mask)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ x = self.proj_out(x)
+ return x + x_in
diff --git a/audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py b/audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py
new file mode 100755
index 0000000000000000000000000000000000000000..f02fa05e163076641b92bbeabceb5f89edb0f18e
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py
@@ -0,0 +1,149 @@
+"""
+Reference Repo: https://github.com/facebookresearch/AudioMAE
+"""
+
+import torch
+import torch.nn as nn
+from timm.models.layers import to_2tuple
+import audioldm2.latent_diffusion.modules.audiomae.models_vit as models_vit
+import audioldm2.latent_diffusion.modules.audiomae.models_mae as models_mae
+
+# model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
+
+
+class PatchEmbed_new(nn.Module):
+ """Flexible Image to Patch Embedding"""
+
+ def __init__(
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
+ ):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ stride = to_2tuple(stride)
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_size, stride=stride
+ ) # with overlapped patches
+ # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
+ # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
+ self.patch_hw = (h, w)
+ self.num_patches = h * w
+
+ def get_output_shape(self, img_size):
+ # todo: don't be lazy..
+ return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ # assert H == self.img_size[0] and W == self.img_size[1], \
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x)
+ x = x.flatten(2).transpose(1, 2)
+ return x
+
+
+class AudioMAE(nn.Module):
+ """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)"""
+
+ def __init__(
+ self,
+ ):
+ super().__init__()
+ model = models_vit.__dict__["vit_base_patch16"](
+ num_classes=527,
+ drop_path_rate=0.1,
+ global_pool=True,
+ mask_2d=True,
+ use_custom_patch=False,
+ )
+
+ img_size = (1024, 128)
+ emb_dim = 768
+
+ model.patch_embed = PatchEmbed_new(
+ img_size=img_size,
+ patch_size=(16, 16),
+ in_chans=1,
+ embed_dim=emb_dim,
+ stride=16,
+ )
+ num_patches = model.patch_embed.num_patches
+ # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8
+ model.pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False
+ ) # fixed sin-cos embedding
+
+ # checkpoint_path = '/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth'
+ # checkpoint = torch.load(checkpoint_path, map_location='cpu')
+ # msg = model.load_state_dict(checkpoint['model'], strict=False)
+ # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
+
+ self.model = model
+
+ def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
+ """
+ x: mel fbank [Batch, 1, T, F]
+ mask_t_prob: 'T masking ratio (percentage of removed patches).'
+ mask_f_prob: 'F masking ratio (percentage of removed patches).'
+ """
+ return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)
+
+
+class Vanilla_AudioMAE(nn.Module):
+ """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)"""
+
+ def __init__(
+ self,
+ ):
+ super().__init__()
+ model = models_mae.__dict__["mae_vit_base_patch16"](
+ in_chans=1, audio_exp=True, img_size=(1024, 128)
+ )
+
+ # checkpoint_path = '/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth'
+ # checkpoint = torch.load(checkpoint_path, map_location='cpu')
+ # msg = model.load_state_dict(checkpoint['model'], strict=False)
+
+ # Skip the missing keys of decoder modules (not required)
+ # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
+
+ self.model = model.eval()
+
+ def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
+ """
+ x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
+ mask_ratio: 'masking ratio (percentage of removed patches).'
+ """
+ with torch.no_grad():
+ # embed: [B, 513, 768] for mask_ratio=0.0
+ if no_mask:
+ if no_average:
+ raise RuntimeError("This function is deprecated")
+ embed = self.model.forward_encoder_no_random_mask_no_average(
+ x
+ ) # mask_ratio
+ else:
+ embed = self.model.forward_encoder_no_mask(x) # mask_ratio
+ else:
+ raise RuntimeError("This function is deprecated")
+ embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio)
+ return embed
+
+
+if __name__ == "__main__":
+ model = Vanilla_AudioMAE().cuda()
+ input = torch.randn(4, 1, 1024, 128).cuda()
+ print("The first run")
+ embed = model(input, mask_ratio=0.0, no_mask=True)
+ print(embed)
+ print("The second run")
+ embed = model(input, mask_ratio=0.0)
+ print(embed)
diff --git a/audioldm2/latent_diffusion/modules/audiomae/__init__.py b/audioldm2/latent_diffusion/modules/audiomae/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/modules/audiomae/models_mae.py b/audioldm2/latent_diffusion/modules/audiomae/models_mae.py
new file mode 100755
index 0000000000000000000000000000000000000000..7ab0076710a08a7451dd4096bd6eb2f8f6e641aa
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/models_mae.py
@@ -0,0 +1,613 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.models.vision_transformer import Block
+from audioldm2.latent_diffusion.modules.audiomae.util.pos_embed import (
+ get_2d_sincos_pos_embed,
+ get_2d_sincos_pos_embed_flexible,
+)
+from audioldm2.latent_diffusion.modules.audiomae.util.patch_embed import (
+ PatchEmbed_new,
+ PatchEmbed_org,
+)
+
+
+class MaskedAutoencoderViT(nn.Module):
+ """Masked Autoencoder with VisionTransformer backbone"""
+
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ stride=10,
+ in_chans=3,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ decoder_embed_dim=512,
+ decoder_depth=8,
+ decoder_num_heads=16,
+ mlp_ratio=4.0,
+ norm_layer=nn.LayerNorm,
+ norm_pix_loss=False,
+ audio_exp=False,
+ alpha=0.0,
+ temperature=0.2,
+ mode=0,
+ contextual_depth=8,
+ use_custom_patch=False,
+ split_pos=False,
+ pos_trainable=False,
+ use_nce=False,
+ beta=4.0,
+ decoder_mode=0,
+ mask_t_prob=0.6,
+ mask_f_prob=0.5,
+ mask_2d=False,
+ epoch=0,
+ no_shift=False,
+ ):
+ super().__init__()
+
+ self.audio_exp = audio_exp
+ self.embed_dim = embed_dim
+ self.decoder_embed_dim = decoder_embed_dim
+ # --------------------------------------------------------------------------
+ # MAE encoder specifics
+ if use_custom_patch:
+ print(
+ f"Use custom patch_emb with patch size: {patch_size}, stride: {stride}"
+ )
+ self.patch_embed = PatchEmbed_new(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ stride=stride,
+ )
+ else:
+ self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)
+ self.use_custom_patch = use_custom_patch
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+ # self.split_pos = split_pos # not useful
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable
+ ) # fixed sin-cos embedding
+
+ self.encoder_depth = depth
+ self.contextual_depth = contextual_depth
+ self.blocks = nn.ModuleList(
+ [
+ Block(
+ embed_dim,
+ num_heads,
+ mlp_ratio,
+ qkv_bias=True,
+ norm_layer=norm_layer,
+ ) # qk_scale=None
+ for i in range(depth)
+ ]
+ )
+ self.norm = norm_layer(embed_dim)
+
+ # --------------------------------------------------------------------------
+ # MAE decoder specifics
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
+
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+ self.decoder_pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches + 1, decoder_embed_dim),
+ requires_grad=pos_trainable,
+ ) # fixed sin-cos embedding
+
+ self.no_shift = no_shift
+
+ self.decoder_mode = decoder_mode
+ if (
+ self.use_custom_patch
+ ): # overlapped patches as in AST. Similar performance yet compute heavy
+ window_size = (6, 6)
+ feat_size = (102, 12)
+ else:
+ window_size = (4, 4)
+ feat_size = (64, 8)
+ if self.decoder_mode == 1:
+ decoder_modules = []
+ for index in range(16):
+ if self.no_shift:
+ shift_size = (0, 0)
+ else:
+ if (index % 2) == 0:
+ shift_size = (0, 0)
+ else:
+ shift_size = (2, 0)
+ # shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size])
+ decoder_modules.append(
+ SwinTransformerBlock(
+ dim=decoder_embed_dim,
+ num_heads=16,
+ feat_size=feat_size,
+ window_size=window_size,
+ shift_size=shift_size,
+ mlp_ratio=mlp_ratio,
+ drop=0.0,
+ drop_attn=0.0,
+ drop_path=0.0,
+ extra_norm=False,
+ sequential_attn=False,
+ norm_layer=norm_layer, # nn.LayerNorm,
+ )
+ )
+ self.decoder_blocks = nn.ModuleList(decoder_modules)
+ else:
+ # Transfomer
+ self.decoder_blocks = nn.ModuleList(
+ [
+ Block(
+ decoder_embed_dim,
+ decoder_num_heads,
+ mlp_ratio,
+ qkv_bias=True,
+ norm_layer=norm_layer,
+ ) # qk_scale=None,
+ for i in range(decoder_depth)
+ ]
+ )
+
+ self.decoder_norm = norm_layer(decoder_embed_dim)
+ self.decoder_pred = nn.Linear(
+ decoder_embed_dim, patch_size**2 * in_chans, bias=True
+ ) # decoder to patch
+
+ # --------------------------------------------------------------------------
+
+ self.norm_pix_loss = norm_pix_loss
+
+ self.patch_size = patch_size
+ self.stride = stride
+
+ # audio exps
+ self.alpha = alpha
+ self.T = temperature
+ self.mode = mode
+ self.use_nce = use_nce
+ self.beta = beta
+
+ self.log_softmax = nn.LogSoftmax(dim=-1)
+
+ self.mask_t_prob = mask_t_prob
+ self.mask_f_prob = mask_f_prob
+ self.mask_2d = mask_2d
+
+ self.epoch = epoch
+
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ # initialization
+ # initialize (and freeze) pos_embed by sin-cos embedding
+ if self.audio_exp:
+ pos_embed = get_2d_sincos_pos_embed_flexible(
+ self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True
+ )
+ else:
+ pos_embed = get_2d_sincos_pos_embed(
+ self.pos_embed.shape[-1],
+ int(self.patch_embed.num_patches**0.5),
+ cls_token=True,
+ )
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ if self.audio_exp:
+ decoder_pos_embed = get_2d_sincos_pos_embed_flexible(
+ self.decoder_pos_embed.shape[-1],
+ self.patch_embed.patch_hw,
+ cls_token=True,
+ )
+ else:
+ decoder_pos_embed = get_2d_sincos_pos_embed(
+ self.decoder_pos_embed.shape[-1],
+ int(self.patch_embed.num_patches**0.5),
+ cls_token=True,
+ )
+ self.decoder_pos_embed.data.copy_(
+ torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
+ )
+
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+ w = self.patch_embed.proj.weight.data
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
+ torch.nn.init.normal_(self.cls_token, std=0.02)
+ torch.nn.init.normal_(self.mask_token, std=0.02)
+
+ # initialize nn.Linear and nn.LayerNorm
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ # we use xavier_uniform following official JAX ViT:
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def patchify(self, imgs):
+ """
+ imgs: (N, 3, H, W)
+ x: (N, L, patch_size**2 *3)
+ L = (H/p)*(W/p)
+ """
+ p = self.patch_embed.patch_size[0]
+ # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+ if self.audio_exp:
+ if self.use_custom_patch: # overlapped patch
+ h, w = self.patch_embed.patch_hw
+ # todo: fixed h/w patch size and stride size. Make hw custom in the future
+ x = imgs.unfold(2, self.patch_size, self.stride).unfold(
+ 3, self.patch_size, self.stride
+ ) # n,1,H,W -> n,1,h,w,p,p
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
+ # x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
+ # x = torch.einsum('nchpwq->nhwpqc', x)
+ # x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
+ else:
+ h = imgs.shape[2] // p
+ w = imgs.shape[3] // p
+ # h,w = self.patch_embed.patch_hw
+ x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
+ x = torch.einsum("nchpwq->nhwpqc", x)
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
+ else:
+ h = w = imgs.shape[2] // p
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+ x = torch.einsum("nchpwq->nhwpqc", x)
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+
+ return x
+
+ def unpatchify(self, x):
+ """
+ x: (N, L, patch_size**2 *3)
+ specs: (N, 1, H, W)
+ """
+ p = self.patch_embed.patch_size[0]
+ h = 1024 // p
+ w = 128 // p
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))
+ x = torch.einsum("nhwpqc->nchpwq", x)
+ specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))
+ return specs
+
+ def random_masking(self, x, mask_ratio):
+ """
+ Perform per-sample random masking by per-sample shuffling.
+ Per-sample shuffling is done by argsort random noise.
+ x: [N, L, D], sequence
+ """
+ N, L, D = x.shape # batch, length, dim
+ len_keep = int(L * (1 - mask_ratio))
+
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
+
+ # sort noise for each sample
+ ids_shuffle = torch.argsort(
+ noise, dim=1
+ ) # ascend: small is keep, large is remove
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+ # keep the first subset
+ ids_keep = ids_shuffle[:, :len_keep]
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+ # generate the binary mask: 0 is keep, 1 is remove
+ mask = torch.ones([N, L], device=x.device)
+ mask[:, :len_keep] = 0
+ # unshuffle to get the binary mask
+ mask = torch.gather(mask, dim=1, index=ids_restore)
+
+ return x_masked, mask, ids_restore
+
+ def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
+ """
+ 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob)
+ Perform per-sample random masking by per-sample shuffling.
+ Per-sample shuffling is done by argsort random noise.
+ x: [N, L, D], sequence
+ """
+ N, L, D = x.shape # batch, length, dim
+ if self.use_custom_patch: # overlapped patch
+ T = 101
+ F = 12
+ else:
+ T = 64
+ F = 8
+ # x = x.reshape(N, T, F, D)
+ len_keep_t = int(T * (1 - mask_t_prob))
+ len_keep_f = int(F * (1 - mask_f_prob))
+
+ # noise for mask in time
+ noise_t = torch.rand(N, T, device=x.device) # noise in [0, 1]
+ # sort noise for each sample aling time
+ ids_shuffle_t = torch.argsort(
+ noise_t, dim=1
+ ) # ascend: small is keep, large is remove
+ ids_restore_t = torch.argsort(ids_shuffle_t, dim=1)
+ ids_keep_t = ids_shuffle_t[:, :len_keep_t]
+ # noise mask in freq
+ noise_f = torch.rand(N, F, device=x.device) # noise in [0, 1]
+ ids_shuffle_f = torch.argsort(
+ noise_f, dim=1
+ ) # ascend: small is keep, large is remove
+ ids_restore_f = torch.argsort(ids_shuffle_f, dim=1)
+ ids_keep_f = ids_shuffle_f[:, :len_keep_f] #
+
+ # generate the binary mask: 0 is keep, 1 is remove
+ # mask in freq
+ mask_f = torch.ones(N, F, device=x.device)
+ mask_f[:, :len_keep_f] = 0
+ mask_f = (
+ torch.gather(mask_f, dim=1, index=ids_restore_f)
+ .unsqueeze(1)
+ .repeat(1, T, 1)
+ ) # N,T,F
+ # mask in time
+ mask_t = torch.ones(N, T, device=x.device)
+ mask_t[:, :len_keep_t] = 0
+ mask_t = (
+ torch.gather(mask_t, dim=1, index=ids_restore_t)
+ .unsqueeze(1)
+ .repeat(1, F, 1)
+ .permute(0, 2, 1)
+ ) # N,T,F
+ mask = 1 - (1 - mask_t) * (1 - mask_f) # N, T, F
+
+ # get masked x
+ id2res = torch.Tensor(list(range(N * T * F))).reshape(N, T, F).to(x.device)
+ id2res = id2res + 999 * mask # add a large value for masked elements
+ id2res2 = torch.argsort(id2res.flatten(start_dim=1))
+ ids_keep = id2res2.flatten(start_dim=1)[:, : len_keep_f * len_keep_t]
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+ ids_restore = torch.argsort(id2res2.flatten(start_dim=1))
+ mask = mask.flatten(start_dim=1)
+
+ return x_masked, mask, ids_restore
+
+ def forward_encoder(self, x, mask_ratio, mask_2d=False):
+ # embed patches
+ x = self.patch_embed(x)
+ # add pos embed w/o cls token
+ x = x + self.pos_embed[:, 1:, :]
+
+ # masking: length -> length * mask_ratio
+ if mask_2d:
+ x, mask, ids_restore = self.random_masking_2d(
+ x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob
+ )
+ else:
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
+
+ # append cls token
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # apply Transformer blocks
+ for blk in self.blocks:
+ x = blk(x)
+ x = self.norm(x)
+
+ return x, mask, ids_restore, None
+
+ def forward_encoder_no_random_mask_no_average(self, x):
+ # embed patches
+ x = self.patch_embed(x)
+ # add pos embed w/o cls token
+ x = x + self.pos_embed[:, 1:, :]
+
+ # masking: length -> length * mask_ratio
+ # if mask_2d:
+ # x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob)
+ # else:
+ # x, mask, ids_restore = self.random_masking(x, mask_ratio)
+
+ # append cls token
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # apply Transformer blocks
+ for blk in self.blocks:
+ x = blk(x)
+ x = self.norm(x)
+
+ return x
+
+ def forward_encoder_no_mask(self, x):
+ # embed patches
+ x = self.patch_embed(x)
+
+ # add pos embed w/o cls token
+ x = x + self.pos_embed[:, 1:, :]
+
+ # masking: length -> length * mask_ratio
+ # x, mask, ids_restore = self.random_masking(x, mask_ratio)
+ # append cls token
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # apply Transformer blocks
+ contextual_embs = []
+ for n, blk in enumerate(self.blocks):
+ x = blk(x)
+ if n > self.contextual_depth:
+ contextual_embs.append(self.norm(x))
+ # x = self.norm(x)
+ contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0)
+
+ return contextual_emb
+
+ def forward_decoder(self, x, ids_restore):
+ # embed tokens
+ x = self.decoder_embed(x)
+
+ # append mask tokens to sequence
+ mask_tokens = self.mask_token.repeat(
+ x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
+ )
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
+ x_ = torch.gather(
+ x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
+ ) # unshuffle
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
+
+ # add pos embed
+ x = x + self.decoder_pos_embed
+
+ if self.decoder_mode != 0:
+ B, L, D = x.shape
+ x = x[:, 1:, :]
+ if self.use_custom_patch:
+ x = x.reshape(B, 101, 12, D)
+ x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1) # hack
+ x = x.reshape(B, 1224, D)
+ if self.decoder_mode > 3: # mvit
+ x = self.decoder_blocks(x)
+ else:
+ # apply Transformer blocks
+ for blk in self.decoder_blocks:
+ x = blk(x)
+ x = self.decoder_norm(x)
+
+ # predictor projection
+ pred = self.decoder_pred(x)
+
+ # remove cls token
+ if self.decoder_mode != 0:
+ if self.use_custom_patch:
+ pred = pred.reshape(B, 102, 12, 256)
+ pred = pred[:, :101, :, :]
+ pred = pred.reshape(B, 1212, 256)
+ else:
+ pred = pred
+ else:
+ pred = pred[:, 1:, :]
+ return pred, None, None # emb, emb_pixel
+
+ def forward_loss(self, imgs, pred, mask, norm_pix_loss=False):
+ """
+ imgs: [N, 3, H, W]
+ pred: [N, L, p*p*3]
+ mask: [N, L], 0 is keep, 1 is remove,
+ """
+ target = self.patchify(imgs)
+ if norm_pix_loss:
+ mean = target.mean(dim=-1, keepdim=True)
+ var = target.var(dim=-1, keepdim=True)
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
+
+ loss = (pred - target) ** 2
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
+
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
+ return loss
+
+ def forward(self, imgs, mask_ratio=0.8):
+ emb_enc, mask, ids_restore, _ = self.forward_encoder(
+ imgs, mask_ratio, mask_2d=self.mask_2d
+ )
+ pred, _, _ = self.forward_decoder(emb_enc, ids_restore) # [N, L, p*p*3]
+ loss_recon = self.forward_loss(
+ imgs, pred, mask, norm_pix_loss=self.norm_pix_loss
+ )
+ loss_contrastive = torch.FloatTensor([0.0]).cuda()
+ return loss_recon, pred, mask, loss_contrastive
+
+
+def mae_vit_small_patch16_dec512d8b(**kwargs):
+ model = MaskedAutoencoderViT(
+ patch_size=16,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ decoder_embed_dim=512,
+ decoder_num_heads=16,
+ mlp_ratio=4,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def mae_vit_base_patch16_dec512d8b(**kwargs):
+ model = MaskedAutoencoderViT(
+ patch_size=16,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ decoder_embed_dim=512,
+ decoder_num_heads=16,
+ mlp_ratio=4,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def mae_vit_large_patch16_dec512d8b(**kwargs):
+ model = MaskedAutoencoderViT(
+ patch_size=16,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ decoder_embed_dim=512,
+ decoder_num_heads=16,
+ mlp_ratio=4,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+def mae_vit_huge_patch14_dec512d8b(**kwargs):
+ model = MaskedAutoencoderViT(
+ patch_size=14,
+ embed_dim=1280,
+ depth=32,
+ num_heads=16,
+ decoder_embed_dim=512,
+ decoder_num_heads=16,
+ mlp_ratio=4,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs,
+ )
+ return model
+
+
+# set recommended archs
+mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks
+mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
+mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks
+mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks
diff --git a/audioldm2/latent_diffusion/modules/audiomae/models_vit.py b/audioldm2/latent_diffusion/modules/audiomae/models_vit.py
new file mode 100755
index 0000000000000000000000000000000000000000..cb37adbc16cfb9a232493c473c9400f199655b6c
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/models_vit.py
@@ -0,0 +1,243 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+import timm.models.vision_transformer
+
+
+class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
+ """Vision Transformer with support for global average pooling"""
+
+ def __init__(
+ self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs
+ ):
+ super(VisionTransformer, self).__init__(**kwargs)
+
+ self.global_pool = global_pool
+ if self.global_pool:
+ norm_layer = kwargs["norm_layer"]
+ embed_dim = kwargs["embed_dim"]
+ self.fc_norm = norm_layer(embed_dim)
+ del self.norm # remove the original norm
+ self.mask_2d = mask_2d
+ self.use_custom_patch = use_custom_patch
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+ x = x + self.pos_embed[:, 1:, :]
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
+ cls_tokens = cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ if self.global_pool:
+ x = x[:, 1:, :].mean(dim=1) # global pool without cls token
+ outcome = self.fc_norm(x)
+ else:
+ x = self.norm(x)
+ outcome = x[:, 0]
+
+ return outcome
+
+ def random_masking(self, x, mask_ratio):
+ """
+ Perform per-sample random masking by per-sample shuffling.
+ Per-sample shuffling is done by argsort random noise.
+ x: [N, L, D], sequence
+ """
+ N, L, D = x.shape # batch, length, dim
+ len_keep = int(L * (1 - mask_ratio))
+
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
+
+ # sort noise for each sample
+ ids_shuffle = torch.argsort(
+ noise, dim=1
+ ) # ascend: small is keep, large is remove
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+ # keep the first subset
+ ids_keep = ids_shuffle[:, :len_keep]
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+ # generate the binary mask: 0 is keep, 1 is remove
+ mask = torch.ones([N, L], device=x.device)
+ mask[:, :len_keep] = 0
+ # unshuffle to get the binary mask
+ mask = torch.gather(mask, dim=1, index=ids_restore)
+
+ return x_masked, mask, ids_restore
+
+ def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
+ """
+ 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob)
+ Perform per-sample random masking by per-sample shuffling.
+ Per-sample shuffling is done by argsort random noise.
+ x: [N, L, D], sequence
+ """
+
+ N, L, D = x.shape # batch, length, dim
+ if self.use_custom_patch:
+ # # for AS
+ T = 101 # 64,101
+ F = 12 # 8,12
+ # # for ESC
+ # T=50
+ # F=12
+ # for SPC
+ # T=12
+ # F=12
+ else:
+ # ## for AS
+ T = 64
+ F = 8
+ # ## for ESC
+ # T=32
+ # F=8
+ ## for SPC
+ # T=8
+ # F=8
+
+ # mask T
+ x = x.reshape(N, T, F, D)
+ len_keep_T = int(T * (1 - mask_t_prob))
+ noise = torch.rand(N, T, device=x.device) # noise in [0, 1]
+ # sort noise for each sample
+ ids_shuffle = torch.argsort(
+ noise, dim=1
+ ) # ascend: small is keep, large is remove
+ ids_keep = ids_shuffle[:, :len_keep_T]
+ index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D)
+ # x_masked = torch.gather(x, dim=1, index=index)
+ # x_masked = x_masked.reshape(N,len_keep_T*F,D)
+ x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D
+
+ # mask F
+ # x = x.reshape(N, T, F, D)
+ x = x.permute(0, 2, 1, 3) # N T' F D => N F T' D
+ len_keep_F = int(F * (1 - mask_f_prob))
+ noise = torch.rand(N, F, device=x.device) # noise in [0, 1]
+ # sort noise for each sample
+ ids_shuffle = torch.argsort(
+ noise, dim=1
+ ) # ascend: small is keep, large is remove
+ ids_keep = ids_shuffle[:, :len_keep_F]
+ # index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D)
+ index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D)
+ x_masked = torch.gather(x, dim=1, index=index)
+ x_masked = x_masked.permute(0, 2, 1, 3) # N F' T' D => N T' F' D
+ # x_masked = x_masked.reshape(N,len_keep*T,D)
+ x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D)
+
+ return x_masked, None, None
+
+ def forward_features_mask(self, x, mask_t_prob, mask_f_prob):
+ B = x.shape[0] # 4,1,1024,128
+ x = self.patch_embed(x) # 4, 512, 768
+
+ x = x + self.pos_embed[:, 1:, :]
+ if self.random_masking_2d:
+ x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)
+ else:
+ x, mask, ids_restore = self.random_masking(x, mask_t_prob)
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
+ cls_tokens = cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = self.pos_drop(x)
+
+ # apply Transformer blocks
+ for blk in self.blocks:
+ x = blk(x)
+
+ if self.global_pool:
+ x = x[:, 1:, :].mean(dim=1) # global pool without cls token
+ outcome = self.fc_norm(x)
+ else:
+ x = self.norm(x)
+ outcome = x[:, 0]
+
+ return outcome
+
+ # overwrite original timm
+ def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0):
+ if mask_t_prob > 0.0 or mask_f_prob > 0.0:
+ x = self.forward_features_mask(
+ x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob
+ )
+ else:
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+def vit_small_patch16(**kwargs):
+ model = VisionTransformer(
+ patch_size=16,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs
+ )
+ return model
+
+
+def vit_base_patch16(**kwargs):
+ model = VisionTransformer(
+ patch_size=16,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs
+ )
+ return model
+
+
+def vit_large_patch16(**kwargs):
+ model = VisionTransformer(
+ patch_size=16,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs
+ )
+ return model
+
+
+def vit_huge_patch14(**kwargs):
+ model = VisionTransformer(
+ patch_size=14,
+ embed_dim=1280,
+ depth=32,
+ num_heads=16,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs
+ )
+ return model
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/crop.py b/audioldm2/latent_diffusion/modules/audiomae/util/crop.py
new file mode 100755
index 0000000000000000000000000000000000000000..525e3c783c3d348e593dc89c2b5fb8520918e9ea
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/crop.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+
+from torchvision import transforms
+from torchvision.transforms import functional as F
+
+
+class RandomResizedCrop(transforms.RandomResizedCrop):
+ """
+ RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
+ This may lead to results different with torchvision's version.
+ Following BYOL's TF code:
+ https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
+ """
+
+ @staticmethod
+ def get_params(img, scale, ratio):
+ width, height = F._get_image_size(img)
+ area = height * width
+
+ target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
+ log_ratio = torch.log(torch.tensor(ratio))
+ aspect_ratio = torch.exp(
+ torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
+ ).item()
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ w = min(w, width)
+ h = min(h, height)
+
+ i = torch.randint(0, height - h + 1, size=(1,)).item()
+ j = torch.randint(0, width - w + 1, size=(1,)).item()
+
+ return i, j, h, w
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/datasets.py b/audioldm2/latent_diffusion/modules/audiomae/util/datasets.py
new file mode 100755
index 0000000000000000000000000000000000000000..b90f89a7d5f78c31bc9113dd88b632b0c234f10a
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/datasets.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+
+import os
+import PIL
+
+from torchvision import datasets, transforms
+
+from timm.data import create_transform
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+
+def build_dataset(is_train, args):
+ transform = build_transform(is_train, args)
+
+ root = os.path.join(args.data_path, "train" if is_train else "val")
+ dataset = datasets.ImageFolder(root, transform=transform)
+
+ print(dataset)
+
+ return dataset
+
+
+def build_transform(is_train, args):
+ mean = IMAGENET_DEFAULT_MEAN
+ std = IMAGENET_DEFAULT_STD
+ # train transform
+ if is_train:
+ # this should always dispatch to transforms_imagenet_train
+ transform = create_transform(
+ input_size=args.input_size,
+ is_training=True,
+ color_jitter=args.color_jitter,
+ auto_augment=args.aa,
+ interpolation="bicubic",
+ re_prob=args.reprob,
+ re_mode=args.remode,
+ re_count=args.recount,
+ mean=mean,
+ std=std,
+ )
+ return transform
+
+ # eval transform
+ t = []
+ if args.input_size <= 224:
+ crop_pct = 224 / 256
+ else:
+ crop_pct = 1.0
+ size = int(args.input_size / crop_pct)
+ t.append(
+ transforms.Resize(
+ size, interpolation=PIL.Image.BICUBIC
+ ), # to maintain same ratio w.r.t. 224 images
+ )
+ t.append(transforms.CenterCrop(args.input_size))
+
+ t.append(transforms.ToTensor())
+ t.append(transforms.Normalize(mean, std))
+ return transforms.Compose(t)
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/lars.py b/audioldm2/latent_diffusion/modules/audiomae/util/lars.py
new file mode 100755
index 0000000000000000000000000000000000000000..fc43923d22cf2c9af4ae9166612c3f3477faf254
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/lars.py
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# LARS optimizer, implementation from MoCo v3:
+# https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+
+import torch
+
+
+class LARS(torch.optim.Optimizer):
+ """
+ LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
+ """
+
+ def __init__(
+ self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001
+ ):
+ defaults = dict(
+ lr=lr,
+ weight_decay=weight_decay,
+ momentum=momentum,
+ trust_coefficient=trust_coefficient,
+ )
+ super().__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self):
+ for g in self.param_groups:
+ for p in g["params"]:
+ dp = p.grad
+
+ if dp is None:
+ continue
+
+ if p.ndim > 1: # if not normalization gamma/beta or bias
+ dp = dp.add(p, alpha=g["weight_decay"])
+ param_norm = torch.norm(p)
+ update_norm = torch.norm(dp)
+ one = torch.ones_like(param_norm)
+ q = torch.where(
+ param_norm > 0.0,
+ torch.where(
+ update_norm > 0,
+ (g["trust_coefficient"] * param_norm / update_norm),
+ one,
+ ),
+ one,
+ )
+ dp = dp.mul(q)
+
+ param_state = self.state[p]
+ if "mu" not in param_state:
+ param_state["mu"] = torch.zeros_like(p)
+ mu = param_state["mu"]
+ mu.mul_(g["momentum"]).add_(dp)
+ p.add_(mu, alpha=-g["lr"])
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py b/audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py
new file mode 100755
index 0000000000000000000000000000000000000000..e90ed69d7b8d019dbf5d90571541668e2bd8efe8
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# ELECTRA https://github.com/google-research/electra
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+
+def param_groups_lrd(
+ model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75
+):
+ """
+ Parameter groups for layer-wise lr decay
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
+ """
+ param_group_names = {}
+ param_groups = {}
+
+ num_layers = len(model.blocks) + 1
+
+ layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
+
+ for n, p in model.named_parameters():
+ if not p.requires_grad:
+ continue
+
+ # no decay: all 1D parameters and model specific ones
+ if p.ndim == 1 or n in no_weight_decay_list:
+ g_decay = "no_decay"
+ this_decay = 0.0
+ else:
+ g_decay = "decay"
+ this_decay = weight_decay
+
+ layer_id = get_layer_id_for_vit(n, num_layers)
+ group_name = "layer_%d_%s" % (layer_id, g_decay)
+
+ if group_name not in param_group_names:
+ this_scale = layer_scales[layer_id]
+
+ param_group_names[group_name] = {
+ "lr_scale": this_scale,
+ "weight_decay": this_decay,
+ "params": [],
+ }
+ param_groups[group_name] = {
+ "lr_scale": this_scale,
+ "weight_decay": this_decay,
+ "params": [],
+ }
+
+ param_group_names[group_name]["params"].append(n)
+ param_groups[group_name]["params"].append(p)
+
+ # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
+
+ return list(param_groups.values())
+
+
+def get_layer_id_for_vit(name, num_layers):
+ """
+ Assign a parameter with its layer id
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+ """
+ if name in ["cls_token", "pos_embed"]:
+ return 0
+ elif name.startswith("patch_embed"):
+ return 0
+ elif name.startswith("blocks"):
+ return int(name.split(".")[1]) + 1
+ else:
+ return num_layers
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py b/audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py
new file mode 100755
index 0000000000000000000000000000000000000000..efe184d8e3fb63ec6b4f83375b6ea719985900de
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+
+def adjust_learning_rate(optimizer, epoch, args):
+ """Decay the learning rate with half-cycle cosine after warmup"""
+ if epoch < args.warmup_epochs:
+ lr = args.lr * epoch / args.warmup_epochs
+ else:
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (
+ 1.0
+ + math.cos(
+ math.pi
+ * (epoch - args.warmup_epochs)
+ / (args.epochs - args.warmup_epochs)
+ )
+ )
+ for param_group in optimizer.param_groups:
+ if "lr_scale" in param_group:
+ param_group["lr"] = lr * param_group["lr_scale"]
+ else:
+ param_group["lr"] = lr
+ return lr
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/misc.py b/audioldm2/latent_diffusion/modules/audiomae/util/misc.py
new file mode 100755
index 0000000000000000000000000000000000000000..74184e09e23e0e174350b894b0cff29600c18b71
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/misc.py
@@ -0,0 +1,453 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+import builtins
+import datetime
+import os
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+from torch._six import inf
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value,
+ )
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError(
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+ )
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append("{}: {}".format(name, str(meter)))
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ""
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
+ data_time = SmoothedValue(fmt="{avg:.4f}")
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+ log_msg = [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ ]
+ if torch.cuda.is_available():
+ log_msg.append("max mem: {memory:.0f}")
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB,
+ )
+ )
+ else:
+ print(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ )
+ )
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print(
+ "{} Total time: {} ({:.4f} s / it)".format(
+ header, total_time_str, total_time / len(iterable)
+ )
+ )
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ builtin_print = builtins.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ force = force or (get_world_size() > 8)
+ if is_master or force:
+ now = datetime.datetime.now().time()
+ builtin_print("[{}] ".format(now), end="") # print with time stamp
+ builtin_print(*args, **kwargs)
+
+ builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if args.dist_on_itp:
+ args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
+ args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
+ args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
+ args.dist_url = "tcp://%s:%s" % (
+ os.environ["MASTER_ADDR"],
+ os.environ["MASTER_PORT"],
+ )
+ os.environ["LOCAL_RANK"] = str(args.gpu)
+ os.environ["RANK"] = str(args.rank)
+ os.environ["WORLD_SIZE"] = str(args.world_size)
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+ elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ["WORLD_SIZE"])
+ args.gpu = int(os.environ["LOCAL_RANK"])
+ elif "SLURM_PROCID" in os.environ:
+ args.rank = int(os.environ["SLURM_PROCID"])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print("Not using distributed mode")
+ setup_for_distributed(is_master=True) # hack
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = "nccl"
+ print(
+ "| distributed init (rank {}): {}, gpu {}".format(
+ args.rank, args.dist_url, args.gpu
+ ),
+ flush=True,
+ )
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ )
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+class NativeScalerWithGradNormCount:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self):
+ self._scaler = torch.cuda.amp.GradScaler()
+
+ def __call__(
+ self,
+ loss,
+ optimizer,
+ clip_grad=None,
+ parameters=None,
+ create_graph=False,
+ update_grad=True,
+ ):
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if update_grad:
+ if clip_grad is not None:
+ assert parameters is not None
+ self._scaler.unscale_(
+ optimizer
+ ) # unscale the gradients of optimizer's assigned params in-place
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ else:
+ self._scaler.unscale_(optimizer)
+ norm = get_grad_norm_(parameters)
+ self._scaler.step(optimizer)
+ self._scaler.update()
+ else:
+ norm = None
+ return norm
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.0)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(
+ torch.stack(
+ [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
+ ),
+ norm_type,
+ )
+ return total_norm
+
+
+def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
+ output_dir = Path(args.output_dir)
+ epoch_name = str(epoch)
+ if loss_scaler is not None:
+ checkpoint_paths = [output_dir / ("checkpoint-%s.pth" % epoch_name)]
+ for checkpoint_path in checkpoint_paths:
+ to_save = {
+ "model": model_without_ddp.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "epoch": epoch,
+ "scaler": loss_scaler.state_dict(),
+ "args": args,
+ }
+
+ save_on_master(to_save, checkpoint_path)
+ else:
+ client_state = {"epoch": epoch}
+ model.save_checkpoint(
+ save_dir=args.output_dir,
+ tag="checkpoint-%s" % epoch_name,
+ client_state=client_state,
+ )
+
+
+def load_model(args, model_without_ddp, optimizer, loss_scaler):
+ if args.resume:
+ if args.resume.startswith("https"):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ args.resume, map_location="cpu", check_hash=True
+ )
+ else:
+ checkpoint = torch.load(args.resume, map_location="cpu")
+ model_without_ddp.load_state_dict(checkpoint["model"])
+ print("Resume checkpoint %s" % args.resume)
+ if (
+ "optimizer" in checkpoint
+ and "epoch" in checkpoint
+ and not (hasattr(args, "eval") and args.eval)
+ ):
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ args.start_epoch = checkpoint["epoch"] + 1
+ if "scaler" in checkpoint:
+ loss_scaler.load_state_dict(checkpoint["scaler"])
+ print("With optim & sched!")
+
+
+def all_reduce_mean(x):
+ world_size = get_world_size()
+ if world_size > 1:
+ x_reduce = torch.tensor(x).cuda()
+ dist.all_reduce(x_reduce)
+ x_reduce /= world_size
+ return x_reduce.item()
+ else:
+ return x
+
+
+# utils
+@torch.no_grad()
+def concat_all_gather(tensor):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+ tensors_gather = [
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
+ ]
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+ output = torch.cat(tensors_gather, dim=0)
+ return output
+
+
+def merge_vmae_to_avmae(avmae_state_dict, vmae_ckpt):
+ # keys_to_copy=['pos_embed','patch_embed']
+ # replaced=0
+
+ vmae_ckpt["cls_token"] = vmae_ckpt["cls_token_v"]
+ vmae_ckpt["mask_token"] = vmae_ckpt["mask_token_v"]
+
+ # pos_emb % not trainable, use default
+ pos_embed_v = vmae_ckpt["pos_embed_v"] # 1,589,768
+ pos_embed = pos_embed_v[:, 1:, :] # 1,588,768
+ cls_embed = pos_embed_v[:, 0, :].unsqueeze(1)
+ pos_embed = pos_embed.reshape(1, 2, 14, 14, 768).sum(dim=1) # 1, 14, 14, 768
+ print("Position interpolate from 14,14 to 64,8")
+ pos_embed = pos_embed.permute(0, 3, 1, 2) # 1, 14,14,768 -> 1,768,14,14
+ pos_embed = torch.nn.functional.interpolate(
+ pos_embed, size=(64, 8), mode="bicubic", align_corners=False
+ )
+ pos_embed = pos_embed.permute(0, 2, 3, 1).flatten(
+ 1, 2
+ ) # 1, 14, 14, 768 => 1, 196,768
+ pos_embed = torch.cat((cls_embed, pos_embed), dim=1)
+ assert vmae_ckpt["pos_embed"].shape == pos_embed.shape
+ vmae_ckpt["pos_embed"] = pos_embed
+ # patch_emb
+ # aggregate 3 channels in video-rgb ckpt to 1 channel for audio
+ v_weight = vmae_ckpt["patch_embed_v.proj.weight"] # 768,3,2,16,16
+ new_proj_weight = torch.nn.Parameter(v_weight.sum(dim=2).sum(dim=1).unsqueeze(1))
+ assert new_proj_weight.shape == vmae_ckpt["patch_embed.proj.weight"].shape
+ vmae_ckpt["patch_embed.proj.weight"] = new_proj_weight
+ vmae_ckpt["patch_embed.proj.bias"] = vmae_ckpt["patch_embed_v.proj.bias"]
+
+ # hack
+ vmae_ckpt["norm.weight"] = vmae_ckpt["norm_v.weight"]
+ vmae_ckpt["norm.bias"] = vmae_ckpt["norm_v.bias"]
+
+ # replace transformer encoder
+ for k, v in vmae_ckpt.items():
+ if k.startswith("blocks."):
+ kk = k.replace("blocks.", "blocks_v.")
+ vmae_ckpt[k] = vmae_ckpt[kk]
+ elif k.startswith("blocks_v."):
+ pass
+ else:
+ print(k)
+ print(k)
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py b/audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py
new file mode 100755
index 0000000000000000000000000000000000000000..ac1e4d436c6f79aef9bf1de32cdac5d4f037c775
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py
@@ -0,0 +1,127 @@
+import torch
+import torch.nn as nn
+from timm.models.layers import to_2tuple
+
+
+class PatchEmbed_org(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
+ )
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ # assert H == self.img_size[0] and W == self.img_size[1], \
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x)
+ y = x.flatten(2).transpose(1, 2)
+ return y
+
+
+class PatchEmbed_new(nn.Module):
+ """Flexible Image to Patch Embedding"""
+
+ def __init__(
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
+ ):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ stride = to_2tuple(stride)
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_size, stride=stride
+ ) # with overlapped patches
+ # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
+ # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
+ self.patch_hw = (h, w)
+ self.num_patches = h * w
+
+ def get_output_shape(self, img_size):
+ # todo: don't be lazy..
+ return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ # assert H == self.img_size[0] and W == self.img_size[1], \
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ # x = self.proj(x).flatten(2).transpose(1, 2)
+ x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12
+ x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212
+ x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768
+ return x
+
+
+class PatchEmbed3D_new(nn.Module):
+ """Flexible Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ video_size=(16, 224, 224),
+ patch_size=(2, 16, 16),
+ in_chans=3,
+ embed_dim=768,
+ stride=(2, 16, 16),
+ ):
+ super().__init__()
+
+ self.video_size = video_size
+ self.patch_size = patch_size
+ self.in_chans = in_chans
+
+ self.proj = nn.Conv3d(
+ in_chans, embed_dim, kernel_size=patch_size, stride=stride
+ )
+ _, _, t, h, w = self.get_output_shape(video_size) # n, emb_dim, h, w
+ self.patch_thw = (t, h, w)
+ self.num_patches = t * h * w
+
+ def get_output_shape(self, video_size):
+ # todo: don't be lazy..
+ return self.proj(
+ torch.randn(1, self.in_chans, video_size[0], video_size[1], video_size[2])
+ ).shape
+
+ def forward(self, x):
+ B, C, T, H, W = x.shape
+ x = self.proj(x) # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14
+ x = x.flatten(2) # 32, 768, 1568
+ x = x.transpose(1, 2) # 32, 768, 1568 -> 32, 1568, 768
+ return x
+
+
+if __name__ == "__main__":
+ # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16))
+ # input = torch.rand(8,1,1024,128)
+ # output = patch_emb(input)
+ # print(output.shape) # (8,512,64)
+
+ patch_emb = PatchEmbed3D_new(
+ video_size=(6, 224, 224),
+ patch_size=(2, 16, 16),
+ in_chans=3,
+ embed_dim=768,
+ stride=(2, 16, 16),
+ )
+ input = torch.rand(8, 3, 6, 224, 224)
+ output = patch_emb(input)
+ print(output.shape) # (8,64)
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py b/audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py
new file mode 100755
index 0000000000000000000000000000000000000000..2d9177ed98dffcf35264f38aff94e7f00fb50abf
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py
@@ -0,0 +1,206 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+
+import torch
+
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size[0], dtype=np.float32)
+ grid_w = np.arange(grid_size[1], dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ # omega = np.arange(embed_dim // 2, dtype=np.float)
+ omega = np.arange(embed_dim // 2, dtype=float)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+ if "pos_embed" in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model["pos_embed"]
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches**0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print(
+ "Position interpolate from %dx%d to %dx%d"
+ % (orig_size, orig_size, new_size, new_size)
+ )
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(
+ -1, orig_size, orig_size, embedding_size
+ ).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens,
+ size=(new_size, new_size),
+ mode="bicubic",
+ align_corners=False,
+ )
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model["pos_embed"] = new_pos_embed
+
+
+def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size):
+ if "pos_embed" in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model["pos_embed"]
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ # orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ # new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print(
+ "Position interpolate from %dx%d to %dx%d"
+ % (orig_size[0], orig_size[1], new_size[0], new_size[1])
+ )
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(
+ -1, orig_size[0], orig_size[1], embedding_size
+ ).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens,
+ size=(new_size[0], new_size[1]),
+ mode="bicubic",
+ align_corners=False,
+ )
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model["pos_embed"] = new_pos_embed
+
+
+def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size):
+ if "pos_embed" in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model["pos_embed"]
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ model.pos_embed.shape[-2] - num_patches
+ if orig_size != new_size:
+ print(
+ "Position interpolate from %dx%d to %dx%d"
+ % (orig_size[0], orig_size[1], new_size[0], new_size[1])
+ )
+ # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1)
+ pos_tokens = pos_embed_checkpoint[:, 1:, :] # remove
+ pos_tokens = pos_tokens.reshape(
+ -1, orig_size[0], orig_size[1], embedding_size
+ ) # .permute(0, 3, 1, 2)
+ # pos_tokens = torch.nn.functional.interpolate(
+ # pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False)
+
+ # pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ pos_tokens = pos_tokens[:, :, : new_size[1], :] # assume only time diff
+ pos_tokens = pos_tokens.flatten(1, 2)
+ new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1)
+ checkpoint_model["pos_embed"] = new_pos_embed
+
+
+def interpolate_patch_embed_audio(
+ model,
+ checkpoint_model,
+ orig_channel,
+ new_channel=1,
+ kernel_size=(16, 16),
+ stride=(16, 16),
+ padding=(0, 0),
+):
+ if orig_channel != new_channel:
+ if "patch_embed.proj.weight" in checkpoint_model:
+ # aggregate 3 channels in rgb ckpt to 1 channel for audio
+ new_proj_weight = torch.nn.Parameter(
+ torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze(
+ 1
+ )
+ )
+ checkpoint_model["patch_embed.proj.weight"] = new_proj_weight
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/stat.py b/audioldm2/latent_diffusion/modules/audiomae/util/stat.py
new file mode 100755
index 0000000000000000000000000000000000000000..3f8137249503f6eaa25c3170fe5ef6b87f187347
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/stat.py
@@ -0,0 +1,76 @@
+import numpy as np
+from scipy import stats
+from sklearn import metrics
+import torch
+
+
+def d_prime(auc):
+ standard_normal = stats.norm()
+ d_prime = standard_normal.ppf(auc) * np.sqrt(2.0)
+ return d_prime
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+ tensors_gather = [
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
+ ]
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+ output = torch.cat(tensors_gather, dim=0)
+ return output
+
+
+def calculate_stats(output, target):
+ """Calculate statistics including mAP, AUC, etc.
+
+ Args:
+ output: 2d array, (samples_num, classes_num)
+ target: 2d array, (samples_num, classes_num)
+
+ Returns:
+ stats: list of statistic of each class.
+ """
+
+ classes_num = target.shape[-1]
+ stats = []
+
+ # Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet
+ acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))
+
+ # Class-wise statistics
+ for k in range(classes_num):
+ # Average precision
+ avg_precision = metrics.average_precision_score(
+ target[:, k], output[:, k], average=None
+ )
+
+ # AUC
+ # auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None)
+
+ # Precisions, recalls
+ (precisions, recalls, thresholds) = metrics.precision_recall_curve(
+ target[:, k], output[:, k]
+ )
+
+ # FPR, TPR
+ (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k])
+
+ save_every_steps = 1000 # Sample statistics to reduce size
+ dict = {
+ "precisions": precisions[0::save_every_steps],
+ "recalls": recalls[0::save_every_steps],
+ "AP": avg_precision,
+ "fpr": fpr[0::save_every_steps],
+ "fnr": 1.0 - tpr[0::save_every_steps],
+ # 'auc': auc,
+ # note acc is not class-wise, this is just to keep consistent with other metrics
+ "acc": acc,
+ }
+ stats.append(dict)
+
+ return stats
diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/__init__.py b/audioldm2/latent_diffusion/modules/diffusionmodules/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/model.py b/audioldm2/latent_diffusion/modules/diffusionmodules/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..851f8dd28e80046c5e3c9d95bd37726024f1367c
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/diffusionmodules/model.py
@@ -0,0 +1,1069 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+
+from audioldm2.latent_diffusion.util import instantiate_from_config
+from audioldm2.latent_diffusion.modules.attention import LinearAttention
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class UpsampleTimeStride4(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=5, stride=1, padding=2
+ )
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # Do time downsampling here
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
+ )
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class DownsampleTimeStride4(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # Do time downsampling here
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
+ )
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w).contiguous()
+ q = q.permute(0, 2, 1).contiguous() # b,hw,c
+ k = k.reshape(b, c, h * w).contiguous() # b,c,hw
+ w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w).contiguous()
+ w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(
+ v, w_
+ ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w).contiguous()
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
+ # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ use_timestep=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch * 4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList(
+ [
+ torch.nn.Linear(self.ch, self.temb_ch),
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
+ ]
+ )
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ skip_in = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch * in_ch_mult[i_level]
+ block.append(
+ ResnetBlock(
+ in_channels=block_in + skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x, t=None, context=None):
+ # assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb
+ )
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ downsample_time_stride4_levels=[],
+ **ignore_kwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.downsample_time_stride4_levels = downsample_time_stride4_levels
+
+ if len(self.downsample_time_stride4_levels) > 0:
+ assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
+ "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
+ % str(self.num_resolutions)
+ )
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ if i_level in self.downsample_time_stride4_levels:
+ down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
+ else:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ tanh_out=False,
+ use_linear_attn=False,
+ downsample_time_stride4_levels=[],
+ attn_type="vanilla",
+ **ignorekwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+ self.downsample_time_stride4_levels = downsample_time_stride4_levels
+
+ if len(self.downsample_time_stride4_levels) > 0:
+ assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
+ "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
+ % str(self.num_resolutions)
+ )
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ # print(
+ # "Working with z of shape {} = {} dimensions.".format(
+ # self.z_shape, np.prod(self.z_shape)
+ # )
+ # )
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
+ )
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ if i_level - 1 in self.downsample_time_stride4_levels:
+ up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
+ else:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, z):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList(
+ [
+ nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(
+ in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0,
+ dropout=0.0,
+ ),
+ ResnetBlock(
+ in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0,
+ dropout=0.0,
+ ),
+ ResnetBlock(
+ in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0,
+ dropout=0.0,
+ ),
+ nn.Conv2d(2 * in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True),
+ ]
+ )
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1, 2, 3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ ch,
+ num_res_blocks,
+ resolution,
+ ch_mult=(2, 2),
+ dropout=0.0,
+ ):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(
+ in_channels, mid_channels, kernel_size=3, stride=1, padding=1
+ )
+ self.res_block1 = nn.ModuleList(
+ [
+ ResnetBlock(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0,
+ )
+ for _ in range(depth)
+ ]
+ )
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList(
+ [
+ ResnetBlock(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.conv_out = nn.Conv2d(
+ mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(
+ x,
+ size=(
+ int(round(x.shape[2] * self.factor)),
+ int(round(x.shape[3] * self.factor)),
+ ),
+ )
+ x = self.attn(x).contiguous()
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ ch,
+ resolution,
+ out_ch,
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ ch_mult=(1, 2, 4, 8),
+ rescale_factor=1.0,
+ rescale_module_depth=1,
+ ):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ num_res_blocks=num_res_blocks,
+ ch=ch,
+ ch_mult=ch_mult,
+ z_channels=intermediate_chn,
+ double_z=False,
+ resolution=resolution,
+ attn_resolutions=attn_resolutions,
+ dropout=dropout,
+ resamp_with_conv=resamp_with_conv,
+ out_ch=None,
+ )
+ self.rescaler = LatentRescaler(
+ factor=rescale_factor,
+ in_channels=intermediate_chn,
+ mid_channels=intermediate_chn,
+ out_channels=out_ch,
+ depth=rescale_module_depth,
+ )
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(
+ self,
+ z_channels,
+ out_ch,
+ resolution,
+ num_res_blocks,
+ attn_resolutions,
+ ch,
+ ch_mult=(1, 2, 4, 8),
+ dropout=0.0,
+ resamp_with_conv=True,
+ rescale_factor=1.0,
+ rescale_module_depth=1,
+ ):
+ super().__init__()
+ tmp_chn = z_channels * ch_mult[-1]
+ self.decoder = Decoder(
+ out_ch=out_ch,
+ z_channels=tmp_chn,
+ attn_resolutions=attn_resolutions,
+ dropout=dropout,
+ resamp_with_conv=resamp_with_conv,
+ in_channels=None,
+ num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult,
+ resolution=resolution,
+ ch=ch,
+ )
+ self.rescaler = LatentRescaler(
+ factor=rescale_factor,
+ in_channels=z_channels,
+ mid_channels=tmp_chn,
+ out_channels=tmp_chn,
+ depth=rescale_module_depth,
+ )
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size // in_size)) + 1
+ factor_up = 1.0 + (out_size % in_size)
+ print(
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
+ )
+ self.rescaler = LatentRescaler(
+ factor=factor_up,
+ in_channels=in_channels,
+ mid_channels=2 * in_channels,
+ out_channels=in_channels,
+ )
+ self.decoder = Decoder(
+ out_ch=out_channels,
+ resolution=out_size,
+ z_channels=in_channels,
+ num_res_blocks=2,
+ attn_resolutions=[],
+ in_channels=None,
+ ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)],
+ )
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(
+ f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
+ )
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=4, stride=2, padding=1
+ )
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor == 1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(
+ x, mode=self.mode, align_corners=False, scale_factor=scale_factor
+ )
+ return x
+
+
+class FirstStagePostProcessor(nn.Module):
+ def __init__(
+ self,
+ ch_mult: list,
+ in_channels,
+ pretrained_model: nn.Module = None,
+ reshape=False,
+ n_channels=None,
+ dropout=0.0,
+ pretrained_config=None,
+ ):
+ super().__init__()
+ if pretrained_config is None:
+ assert (
+ pretrained_model is not None
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.pretrained_model = pretrained_model
+ else:
+ assert (
+ pretrained_config is not None
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.instantiate_pretrained(pretrained_config)
+
+ self.do_reshape = reshape
+
+ if n_channels is None:
+ n_channels = self.pretrained_model.encoder.ch
+
+ self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
+ self.proj = nn.Conv2d(
+ in_channels, n_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ blocks = []
+ downs = []
+ ch_in = n_channels
+ for m in ch_mult:
+ blocks.append(
+ ResnetBlock(
+ in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
+ )
+ )
+ ch_in = m * n_channels
+ downs.append(Downsample(ch_in, with_conv=False))
+
+ self.model = nn.ModuleList(blocks)
+ self.downsampler = nn.ModuleList(downs)
+
+ def instantiate_pretrained(self, config):
+ model = instantiate_from_config(config)
+ self.pretrained_model = model.eval()
+ # self.pretrained_model.train = False
+ for param in self.pretrained_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def encode_with_pretrained(self, x):
+ c = self.pretrained_model.encode(x)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ return c
+
+ def forward(self, x):
+ z_fs = self.encode_with_pretrained(x)
+ z = self.proj_norm(z_fs)
+ z = self.proj(z)
+ z = nonlinearity(z)
+
+ for submodel, downmodel in zip(self.model, self.downsampler):
+ z = submodel(z, temb=None)
+ z = downmodel(z)
+
+ if self.do_reshape:
+ z = rearrange(z, "b c h w -> b (h w) c")
+ return z
diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/openaimodel.py b/audioldm2/latent_diffusion/modules/diffusionmodules/openaimodel.py
new file mode 100755
index 0000000000000000000000000000000000000000..e006e5a332c3cde5f4e221f003b270d86b34e933
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,1103 @@
+from abc import abstractmethod
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from audioldm2.latent_diffusion.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from audioldm2.latent_diffusion.modules.attention import SpatialTransformer
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
+ )
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1).contiguous() # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context_list=None, mask_list=None):
+ # The first spatial transformer block does not have context
+ spatial_transformer_id = 0
+ context_list = [None] + context_list
+ mask_list = [None] + mask_list
+
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ if spatial_transformer_id >= len(context_list):
+ context, mask = None, None
+ else:
+ context, mask = (
+ context_list[spatial_transformer_id],
+ mask_list[spatial_transformer_id],
+ )
+
+ x = layer(x, context, mask=mask)
+ spatial_transformer_id += 1
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(
+ dims, self.channels, self.out_channels, 3, padding=padding
+ )
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class TransposedUpsample(nn.Module):
+ "Learned 2x upsampling without padding"
+
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(
+ self.channels, self.out_channels, kernel_size=ks, stride=2
+ )
+
+ def forward(self, x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding,
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(
+ self._forward, (x,), self.parameters(), True
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ # return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1).contiguous()
+ qkv = self.qkv(self.norm(x)).contiguous()
+ h = self.attention(qkv).contiguous()
+ h = self.proj_out(h).contiguous()
+ return (x + h).reshape(b, c, *spatial).contiguous()
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial**2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = (
+ qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1)
+ )
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length).contiguous()
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum(
+ "bts,bcs->bct",
+ weight,
+ v.reshape(bs * self.n_heads, ch, length).contiguous(),
+ )
+ return a.reshape(bs, -1, length).contiguous()
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ extra_sa_layer=True,
+ num_classes=None,
+ extra_film_condition_dim=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=True, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ ):
+ super().__init__()
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert (
+ num_head_channels != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ if num_head_channels == -1:
+ assert (
+ num_heads != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.extra_film_condition_dim = extra_film_condition_dim
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ # assert not (
+ # self.num_classes is not None and self.extra_film_condition_dim is not None
+ # ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim."
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.use_extra_film_by_concat = self.extra_film_condition_dim is not None
+
+ if self.extra_film_condition_dim is not None:
+ self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim)
+ print(
+ "+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. "
+ % self.extra_film_condition_dim
+ )
+
+ if context_dim is not None and not use_spatial_transformer:
+ assert (
+ use_spatial_transformer
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
+
+ if context_dim is not None and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ elif context_dim is None:
+ context_dim = [None] # At least use one spatial transformer
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim
+ if (not self.use_extra_film_by_concat)
+ else time_embed_dim * 2,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ if extra_sa_layer:
+ layers.append(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=None,
+ )
+ )
+ for context_dim_id in range(len(context_dim)):
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ if not use_spatial_transformer
+ else SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim[context_dim_id],
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim
+ if (not self.use_extra_film_by_concat)
+ else time_embed_dim * 2,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ middle_layers = [
+ ResBlock(
+ ch,
+ time_embed_dim
+ if (not self.use_extra_film_by_concat)
+ else time_embed_dim * 2,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ if extra_sa_layer:
+ middle_layers.append(
+ SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=None
+ )
+ )
+ for context_dim_id in range(len(context_dim)):
+ middle_layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ if not use_spatial_transformer
+ else SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim[context_dim_id],
+ )
+ )
+ middle_layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim
+ if (not self.use_extra_film_by_concat)
+ else time_embed_dim * 2,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ self.middle_block = TimestepEmbedSequential(*middle_layers)
+
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim
+ if (not self.use_extra_film_by_concat)
+ else time_embed_dim * 2,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ if extra_sa_layer:
+ layers.append(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=None,
+ )
+ )
+ for context_dim_id in range(len(context_dim)):
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ if not use_spatial_transformer
+ else SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim[context_dim_id],
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim
+ if (not self.use_extra_film_by_concat)
+ else time_embed_dim * 2,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ self.shape_reported = False
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(
+ self,
+ x,
+ timesteps=None,
+ y=None,
+ context_list=None,
+ context_attn_mask_list=None,
+ **kwargs,
+ ):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ if not self.shape_reported:
+ # print("The shape of UNet input is", x.size())
+ self.shape_reported = True
+
+ assert (y is not None) == (
+ self.num_classes is not None or self.extra_film_condition_dim is not None
+ ), "must specify y if and only if the model is class-conditional or film embedding conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ # if self.num_classes is not None:
+ # assert y.shape == (x.shape[0],)
+ # emb = emb + self.label_emb(y)
+
+ if self.use_extra_film_by_concat:
+ emb = th.cat([emb, self.film_emb(y)], dim=-1)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context_list, context_attn_mask_list)
+ hs.append(h)
+ h = self.middle_block(h, emb, context_list, context_attn_mask_list)
+ for module in self.output_blocks:
+ concate_tensor = hs.pop()
+ h = th.cat([h, concate_tensor], dim=1)
+ h = module(h, emb, context_list, context_attn_mask_list)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+
+class EncoderUNetModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+ For usage, see UNet.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool="adaptive",
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.pool = pool
+ if pool == "adaptive":
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
+ nn.Flatten(),
+ )
+ elif pool == "attention":
+ assert num_head_channels != -1
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ AttentionPool2d(
+ (image_size // ds), ch, num_head_channels, out_channels
+ ),
+ )
+ elif pool == "spatial":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ elif pool == "spatial_v2":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ normalization(2048),
+ nn.SiLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ else:
+ raise NotImplementedError(f"Unexpected {pool} pooling")
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ results = []
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = th.cat(results, axis=-1)
+ return self.out(h)
+ else:
+ h = h.type(x.dtype)
+ return self.out(h)
diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/util.py b/audioldm2/latent_diffusion/modules/diffusionmodules/util.py
new file mode 100755
index 0000000000000000000000000000000000000000..0d486f919a7ccc0586bc40225dac0ffb33aed01c
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/diffusionmodules/util.py
@@ -0,0 +1,294 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from audioldm2.latent_diffusion.util import instantiate_from_config
+
+
+def make_beta_schedule(
+ schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
+):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
+ )
+ ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(
+ linear_start, linear_end, n_timestep, dtype=torch.float64
+ )
+ elif schedule == "sqrt":
+ betas = (
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ ** 0.5
+ )
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
+
+def make_ddim_timesteps(
+ ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
+):
+ if ddim_discr_method == "uniform":
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == "quad":
+ ddim_timesteps = (
+ (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
+ ).astype(int)
+ else:
+ raise NotImplementedError(
+ f'There is no ddim discretization method called "{ddim_discr_method}"'
+ )
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f"Selected timesteps for ddim sampler: {steps_out}")
+ return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt(
+ (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
+ )
+ if verbose:
+ print(
+ f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
+ )
+ print(
+ f"For the chosen value of eta, which is {eta}, "
+ f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
+ )
+ return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t).contiguous()
+ return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous()
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+ )
+ else:
+ embedding = repeat(timesteps, "b -> b d", d=dim)
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
+ shape[0], *((1,) * (len(shape) - 1))
+ )
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
diff --git a/audioldm2/latent_diffusion/modules/ema.py b/audioldm2/latent_diffusion/modules/ema.py
new file mode 100755
index 0000000000000000000000000000000000000000..880ca3d205d9b4d7450e146930a93f2e63c58b70
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/ema.py
@@ -0,0 +1,82 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+
+ self.m_name2s_name = {}
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer(
+ "num_updates",
+ torch.tensor(0, dtype=torch.int)
+ if use_num_upates
+ else torch.tensor(-1, dtype=torch.int),
+ )
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ # remove as '.'-character is not allowed in buffers
+ s_name = name.replace(".", "")
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(
+ one_minus_decay * (shadow_params[sname] - m_param[key])
+ )
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/audioldm2/latent_diffusion/modules/encoders/__init__.py b/audioldm2/latent_diffusion/modules/encoders/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/modules/encoders/modules.py b/audioldm2/latent_diffusion/modules/encoders/modules.py
new file mode 100755
index 0000000000000000000000000000000000000000..7a72339840c0c3b667e907ea07ee7cb755eb66fd
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/encoders/modules.py
@@ -0,0 +1,736 @@
+import torch
+import logging
+import torch.nn as nn
+from audioldm2.clap.open_clip import create_model
+from audioldm2.clap.training.data import get_audio_features
+import torchaudio
+from transformers import RobertaTokenizer, AutoTokenizer, T5EncoderModel
+import torch.nn.functional as F
+from audioldm2.latent_diffusion.modules.audiomae.AudioMAE import Vanilla_AudioMAE
+from audioldm2.latent_diffusion.modules.phoneme_encoder.encoder import TextEncoder
+
+from transformers import AutoTokenizer, T5Config
+
+from audioldm2.audiomae_gen.sequence_input import Sequence2AudioMAE
+import numpy as np
+
+"""
+The model forward function can return three types of data:
+1. tensor: used directly as conditioning signal
+2. dict: where there is a main key as condition, there are also other key that you can use to pass loss function and itermediate result. etc.
+3. list: the length is 2, in which the first element is tensor, the second element is attntion mask.
+
+The output shape for the cross attention condition should be:
+x,x_mask = [bs, seq_len, emb_dim], [bs, seq_len]
+
+All the returned data, in which will be used as diffusion input, will need to be in float type
+"""
+
+
+class PhonemeEncoder(nn.Module):
+ def __init__(self, vocabs_size=41, pad_length=250, pad_token_id=None):
+ super().__init__()
+ """
+ encoder = PhonemeEncoder(40)
+ data = torch.randint(0, 39, (2, 250))
+ output = encoder(data)
+ import ipdb;ipdb.set_trace()
+ """
+ assert pad_token_id is not None
+
+ self.device = None
+ self.PAD_LENGTH = int(pad_length)
+ self.pad_token_id = pad_token_id
+ self.pad_token_sequence = torch.tensor([self.pad_token_id] * self.PAD_LENGTH)
+
+ self.text_encoder = TextEncoder(
+ n_vocab=vocabs_size,
+ out_channels=192,
+ hidden_channels=192,
+ filter_channels=768,
+ n_heads=2,
+ n_layers=6,
+ kernel_size=3,
+ p_dropout=0.1,
+ )
+
+ self.learnable_positional_embedding = torch.nn.Parameter(
+ torch.zeros((1, 192, self.PAD_LENGTH))
+ ) # [batchsize, seqlen, padlen]
+ self.learnable_positional_embedding.requires_grad = True
+
+ # Required
+ def get_unconditional_condition(self, batchsize):
+ unconditional_tokens = self.pad_token_sequence.expand(
+ batchsize, self.PAD_LENGTH
+ )
+ return self(unconditional_tokens) # Need to return float type
+
+ # def get_unconditional_condition(self, batchsize):
+
+ # hidden_state = torch.zeros((batchsize, self.PAD_LENGTH, 192)).to(self.device)
+ # attention_mask = torch.ones((batchsize, self.PAD_LENGTH)).to(self.device)
+ # return [hidden_state, attention_mask] # Need to return float type
+
+ def _get_src_mask(self, phoneme):
+ src_mask = phoneme != self.pad_token_id
+ return src_mask
+
+ def _get_src_length(self, phoneme):
+ src_mask = self._get_src_mask(phoneme)
+ length = torch.sum(src_mask, dim=-1)
+ return length
+
+ # def make_empty_condition_unconditional(self, src_length, text_emb, attention_mask):
+ # # src_length: [bs]
+ # # text_emb: [bs, 192, pad_length]
+ # # attention_mask: [bs, pad_length]
+ # mask = src_length[..., None, None] > 1
+ # text_emb = text_emb * mask
+
+ # attention_mask[src_length < 1] = attention_mask[src_length < 1] * 0.0 + 1.0
+ # return text_emb, attention_mask
+
+ def forward(self, phoneme_idx):
+ if self.device is None:
+ self.device = self.learnable_positional_embedding.device
+ self.pad_token_sequence = self.pad_token_sequence.to(self.device)
+
+ src_length = self._get_src_length(phoneme_idx)
+ text_emb, m, logs, text_emb_mask = self.text_encoder(phoneme_idx, src_length)
+ text_emb = text_emb + self.learnable_positional_embedding
+
+ # text_emb, text_emb_mask = self.make_empty_condition_unconditional(src_length, text_emb, text_emb_mask)
+
+ return [
+ text_emb.permute(0, 2, 1),
+ text_emb_mask.squeeze(1),
+ ] # [2, 250, 192], [2, 250]
+
+
+class FlanT5HiddenState(nn.Module):
+ """
+ llama = FlanT5HiddenState()
+ data = ["","this is not an empty sentence"]
+ encoder_hidden_states = llama(data)
+ import ipdb;ipdb.set_trace()
+ """
+
+ def __init__(
+ self, text_encoder_name="google/flan-t5-large", freeze_text_encoder=True
+ ):
+ super().__init__()
+ self.freeze_text_encoder = freeze_text_encoder
+ self.tokenizer = AutoTokenizer.from_pretrained(text_encoder_name)
+ self.model = T5EncoderModel(T5Config.from_pretrained(text_encoder_name))
+ if freeze_text_encoder:
+ self.model.eval()
+ for p in self.model.parameters():
+ p.requires_grad = False
+ else:
+ print("=> The text encoder is learnable")
+
+ self.empty_hidden_state_cfg = None
+ self.device = None
+
+ # Required
+ def get_unconditional_condition(self, batchsize):
+ param = next(self.model.parameters())
+ if self.freeze_text_encoder:
+ assert param.requires_grad == False
+
+ # device = param.device
+ if self.empty_hidden_state_cfg is None:
+ self.empty_hidden_state_cfg, _ = self([""])
+
+ hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float()
+ attention_mask = (
+ torch.ones((batchsize, hidden_state.size(1)))
+ .to(hidden_state.device)
+ .float()
+ )
+ return [hidden_state, attention_mask] # Need to return float type
+
+ def forward(self, batch):
+ param = next(self.model.parameters())
+ if self.freeze_text_encoder:
+ assert param.requires_grad == False
+
+ if self.device is None:
+ self.device = param.device
+
+ # print("Manually change text")
+ # for i in range(len(batch)):
+ # batch[i] = "dog barking"
+ try:
+ return self.encode_text(batch)
+ except Exception as e:
+ print(e, batch)
+ logging.exception("An error occurred: %s", str(e))
+
+ def encode_text(self, prompt):
+ device = self.model.device
+ batch = self.tokenizer(
+ prompt,
+ max_length=128, # self.tokenizer.model_max_length
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ )
+ input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
+ device
+ )
+ # Get text encoding
+ if self.freeze_text_encoder:
+ with torch.no_grad():
+ encoder_hidden_states = self.model(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0]
+ else:
+ encoder_hidden_states = self.model(
+ input_ids=input_ids, attention_mask=attention_mask
+ )[0]
+ return [
+ encoder_hidden_states.detach(),
+ attention_mask.float(),
+ ] # Attention mask == 1 means usable token
+
+
+class SequenceGenAudioMAECond(Sequence2AudioMAE):
+ def __init__(
+ self,
+ cond_stage_config,
+ base_learning_rate,
+ sequence_gen_length,
+ sequence_input_key,
+ sequence_input_embed_dim,
+ batchsize,
+ always_output_audiomae_gt=False,
+ pretrained_path=None,
+ force_reload_pretrain_avoid_overwrite=False,
+ learnable=True,
+ use_warmup=True,
+ device=None,
+ use_gt_mae_output=None, # False: does not use AudioMAE GT, True: Use AudioMAE GT
+ use_gt_mae_prob=None,
+ ): # The prob of using AudioMAE GT
+ if use_warmup:
+ use_warmup = False
+
+ super().__init__(
+ base_learning_rate=base_learning_rate,
+ cond_stage_config=cond_stage_config,
+ sequence_gen_length=sequence_gen_length,
+ sequence_input_key=sequence_input_key,
+ use_warmup=use_warmup,
+ sequence_input_embed_dim=sequence_input_embed_dim,
+ batchsize=batchsize,
+ )
+
+ assert use_gt_mae_output is not None and use_gt_mae_prob is not None
+ self.always_output_audiomae_gt = always_output_audiomae_gt
+ self.force_reload_pretrain_avoid_overwrite = (
+ force_reload_pretrain_avoid_overwrite
+ )
+ self.pretrained_path = pretrained_path
+ self.device = device
+ if self.force_reload_pretrain_avoid_overwrite:
+ self.is_reload = False
+ else:
+ self.is_reload = True
+
+ self.load_pretrain_model()
+
+ self.use_gt_mae_output = use_gt_mae_output
+ self.use_gt_mae_prob = use_gt_mae_prob
+ self.learnable = learnable
+
+ if not learnable:
+ # Only optimize the GPT2 model
+ for p in self.model.parameters():
+ p.requires_grad = False
+ self.eval()
+
+ def load_pretrain_model(self):
+ if self.pretrained_path is not None:
+ print("Reload SequenceGenAudioMAECond from %s" % self.pretrained_path)
+ state_dict = torch.load(self.pretrained_path)["state_dict"]
+ self.load_state_dict(state_dict)
+
+ # Required
+ def get_unconditional_condition(self, batchsize):
+ return_dict = self.cfg_uncond(batchsize)
+ return_dict["crossattn_audiomae_generated"] = [
+ return_dict["crossattn_audiomae_pooled"][0],
+ torch.ones_like(return_dict["crossattn_audiomae_pooled"][1]).float(),
+ ]
+ return return_dict
+
+ def forward(self, batch):
+ # The conditional module can return both tensor or dictionaries
+ # The returned tensor will be corresponding to the cond_stage_key
+ # The returned dict will have keys that correspond to the cond_stage_key
+ ret_dict = {}
+
+ if self.force_reload_pretrain_avoid_overwrite and not self.is_reload:
+ self.load_pretrain_model()
+ self.is_reload = True
+
+ # if(self.always_output_audiomae_gt or (self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob)):
+ # cond_dict = self.get_input(batch)
+ # ret_dict["crossattn_audiomae_generated"] = [cond_dict["crossattn_audiomae_pooled"][0], torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float()] # Input sequence and mask
+ # else:
+ input_embeds, cond_dict = self.generate(batch)
+ input_embeds_mask = (
+ torch.ones((input_embeds.size(0), input_embeds.size(1)))
+ .to(input_embeds.device)
+ .float()
+ )
+ ret_dict["crossattn_audiomae_generated"] = [
+ input_embeds,
+ input_embeds_mask,
+ ] # Input sequence and mask
+
+ # If the following two keys are not in cond_stage_key, then they will not be used as condition
+ for key in cond_dict.keys():
+ ret_dict[key] = cond_dict[key]
+
+ return ret_dict
+
+
+class AudioMAEConditionCTPoolRandTFSeparated(nn.Module):
+ """
+ audiomae = AudioMAEConditionCTPool2x2()
+ data = torch.randn((4, 1024, 128))
+ output = audiomae(data)
+ import ipdb;ipdb.set_trace()
+ exit(0)
+ """
+
+ def __init__(
+ self,
+ time_pooling_factors=[1, 2, 4, 8],
+ freq_pooling_factors=[1, 2, 4, 8],
+ eval_time_pooling=None,
+ eval_freq_pooling=None,
+ mask_ratio=0.0,
+ regularization=False,
+ no_audiomae_mask=True,
+ no_audiomae_average=False,
+ ):
+ super().__init__()
+ self.device = None
+ self.time_pooling_factors = time_pooling_factors
+ self.freq_pooling_factors = freq_pooling_factors
+ self.no_audiomae_mask = no_audiomae_mask
+ self.no_audiomae_average = no_audiomae_average
+
+ self.eval_freq_pooling = eval_freq_pooling
+ self.eval_time_pooling = eval_time_pooling
+ self.mask_ratio = mask_ratio
+ self.use_reg = regularization
+
+ self.audiomae = Vanilla_AudioMAE()
+ self.audiomae.eval()
+ for p in self.audiomae.parameters():
+ p.requires_grad = False
+
+ # Required
+ def get_unconditional_condition(self, batchsize):
+ param = next(self.audiomae.parameters())
+ assert param.requires_grad == False
+ device = param.device
+ # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors)
+ time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
+ self.eval_freq_pooling, 8
+ )
+ # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))]
+ # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
+ token_num = int(512 / (time_pool * freq_pool))
+ return [
+ torch.zeros((batchsize, token_num, 768)).to(device).float(),
+ torch.ones((batchsize, token_num)).to(device).float(),
+ ]
+
+ def pool(self, representation, time_pool=None, freq_pool=None):
+ assert representation.size(-1) == 768
+ representation = representation[:, 1:, :].transpose(1, 2)
+ bs, embedding_dim, token_num = representation.size()
+ representation = representation.reshape(bs, embedding_dim, 64, 8)
+
+ if self.training:
+ if time_pool is None and freq_pool is None:
+ time_pool = min(
+ 64,
+ self.time_pooling_factors[
+ np.random.choice(list(range(len(self.time_pooling_factors))))
+ ],
+ )
+ freq_pool = min(
+ 8,
+ self.freq_pooling_factors[
+ np.random.choice(list(range(len(self.freq_pooling_factors))))
+ ],
+ )
+ # freq_pool = min(8, time_pool) # TODO here I make some modification.
+ else:
+ time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
+ self.eval_freq_pooling, 8
+ )
+
+ self.avgpooling = nn.AvgPool2d(
+ kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
+ )
+ self.maxpooling = nn.MaxPool2d(
+ kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
+ )
+
+ pooled = (
+ self.avgpooling(representation) + self.maxpooling(representation)
+ ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num]
+ pooled = pooled.flatten(2).transpose(1, 2)
+ return pooled # [bs, token_num, embedding_dim]
+
+ def regularization(self, x):
+ assert x.size(-1) == 768
+ x = F.normalize(x, p=2, dim=-1)
+ return x
+
+ # Required
+ def forward(self, batch, time_pool=None, freq_pool=None):
+ assert batch.size(-2) == 1024 and batch.size(-1) == 128
+
+ if self.device is None:
+ self.device = batch.device
+
+ batch = batch.unsqueeze(1)
+ with torch.no_grad():
+ representation = self.audiomae(
+ batch,
+ mask_ratio=self.mask_ratio,
+ no_mask=self.no_audiomae_mask,
+ no_average=self.no_audiomae_average,
+ )
+ representation = self.pool(representation, time_pool, freq_pool)
+ if self.use_reg:
+ representation = self.regularization(representation)
+ return [
+ representation,
+ torch.ones((representation.size(0), representation.size(1)))
+ .to(representation.device)
+ .float(),
+ ]
+
+
+class AudioMAEConditionCTPoolRand(nn.Module):
+ """
+ audiomae = AudioMAEConditionCTPool2x2()
+ data = torch.randn((4, 1024, 128))
+ output = audiomae(data)
+ import ipdb;ipdb.set_trace()
+ exit(0)
+ """
+
+ def __init__(
+ self,
+ time_pooling_factors=[1, 2, 4, 8],
+ freq_pooling_factors=[1, 2, 4, 8],
+ eval_time_pooling=None,
+ eval_freq_pooling=None,
+ mask_ratio=0.0,
+ regularization=False,
+ no_audiomae_mask=True,
+ no_audiomae_average=False,
+ ):
+ super().__init__()
+ self.device = None
+ self.time_pooling_factors = time_pooling_factors
+ self.freq_pooling_factors = freq_pooling_factors
+ self.no_audiomae_mask = no_audiomae_mask
+ self.no_audiomae_average = no_audiomae_average
+
+ self.eval_freq_pooling = eval_freq_pooling
+ self.eval_time_pooling = eval_time_pooling
+ self.mask_ratio = mask_ratio
+ self.use_reg = regularization
+
+ self.audiomae = Vanilla_AudioMAE()
+ self.audiomae.eval()
+ for p in self.audiomae.parameters():
+ p.requires_grad = False
+
+ # Required
+ def get_unconditional_condition(self, batchsize):
+ param = next(self.audiomae.parameters())
+ assert param.requires_grad == False
+ device = param.device
+ # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors)
+ time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
+ self.eval_freq_pooling, 8
+ )
+ # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))]
+ # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
+ token_num = int(512 / (time_pool * freq_pool))
+ return [
+ torch.zeros((batchsize, token_num, 768)).to(device).float(),
+ torch.ones((batchsize, token_num)).to(device).float(),
+ ]
+
+ def pool(self, representation, time_pool=None, freq_pool=None):
+ assert representation.size(-1) == 768
+ representation = representation[:, 1:, :].transpose(1, 2)
+ bs, embedding_dim, token_num = representation.size()
+ representation = representation.reshape(bs, embedding_dim, 64, 8)
+
+ if self.training:
+ if time_pool is None and freq_pool is None:
+ time_pool = min(
+ 64,
+ self.time_pooling_factors[
+ np.random.choice(list(range(len(self.time_pooling_factors))))
+ ],
+ )
+ # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
+ freq_pool = min(8, time_pool) # TODO here I make some modification.
+ else:
+ time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
+ self.eval_freq_pooling, 8
+ )
+
+ self.avgpooling = nn.AvgPool2d(
+ kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
+ )
+ self.maxpooling = nn.MaxPool2d(
+ kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
+ )
+
+ pooled = (
+ self.avgpooling(representation) + self.maxpooling(representation)
+ ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num]
+ pooled = pooled.flatten(2).transpose(1, 2)
+ return pooled # [bs, token_num, embedding_dim]
+
+ def regularization(self, x):
+ assert x.size(-1) == 768
+ x = F.normalize(x, p=2, dim=-1)
+ return x
+
+ # Required
+ def forward(self, batch, time_pool=None, freq_pool=None):
+ assert batch.size(-2) == 1024 and batch.size(-1) == 128
+
+ if self.device is None:
+ self.device = batch.device
+
+ batch = batch.unsqueeze(1)
+ with torch.no_grad():
+ representation = self.audiomae(
+ batch,
+ mask_ratio=self.mask_ratio,
+ no_mask=self.no_audiomae_mask,
+ no_average=self.no_audiomae_average,
+ )
+ representation = self.pool(representation, time_pool, freq_pool)
+ if self.use_reg:
+ representation = self.regularization(representation)
+ return [
+ representation,
+ torch.ones((representation.size(0), representation.size(1)))
+ .to(representation.device)
+ .float(),
+ ]
+
+
+class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
+ def __init__(
+ self,
+ pretrained_path="",
+ sampling_rate=16000,
+ embed_mode="audio",
+ amodel="HTSAT-base",
+ unconditional_prob=0.1,
+ random_mute=False,
+ max_random_mute_portion=0.5,
+ training_mode=True,
+ ):
+ super().__init__()
+ self.device = "cpu"
+ self.precision = "fp32"
+ self.amodel = amodel # or 'PANN-14'
+ self.tmodel = "roberta" # the best text encoder in our training
+ self.enable_fusion = False # False if you do not want to use the fusion model
+ self.fusion_type = "aff_2d"
+ self.pretrained = pretrained_path
+ self.embed_mode = embed_mode
+ self.embed_mode_orig = embed_mode
+ self.sampling_rate = sampling_rate
+ self.unconditional_prob = unconditional_prob
+ self.random_mute = random_mute
+ self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
+ self.max_random_mute_portion = max_random_mute_portion
+ self.training_mode = training_mode
+ self.model, self.model_cfg = create_model(
+ self.amodel,
+ self.tmodel,
+ self.pretrained,
+ precision=self.precision,
+ device=self.device,
+ enable_fusion=self.enable_fusion,
+ fusion_type=self.fusion_type,
+ )
+ audio_cfg = self.model_cfg["audio_cfg"]
+ self.mel_transform = torchaudio.transforms.MelSpectrogram(
+ sample_rate=audio_cfg["sample_rate"],
+ n_fft=audio_cfg["window_size"],
+ win_length=audio_cfg["window_size"],
+ hop_length=audio_cfg["hop_size"],
+ center=True,
+ pad_mode="reflect",
+ power=2.0,
+ norm=None,
+ onesided=True,
+ n_mels=64,
+ f_min=audio_cfg["fmin"],
+ f_max=audio_cfg["fmax"],
+ )
+ for p in self.model.parameters():
+ p.requires_grad = False
+ self.unconditional_token = None
+ self.model.eval()
+
+ def get_unconditional_condition(self, batchsize):
+ self.unconditional_token = self.model.get_text_embedding(
+ self.tokenizer(["", ""])
+ )[0:1]
+ return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)
+
+ def batch_to_list(self, batch):
+ ret = []
+ for i in range(batch.size(0)):
+ ret.append(batch[i])
+ return ret
+
+ def make_decision(self, probability):
+ if float(torch.rand(1)) < probability:
+ return True
+ else:
+ return False
+
+ def random_uniform(self, start, end):
+ val = torch.rand(1).item()
+ return start + (end - start) * val
+
+ def _random_mute(self, waveform):
+ # waveform: [bs, t-steps]
+ t_steps = waveform.size(-1)
+ for i in range(waveform.size(0)):
+ mute_size = int(
+ self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
+ )
+ mute_start = int(self.random_uniform(0, t_steps - mute_size))
+ waveform[i, mute_start : mute_start + mute_size] = 0
+ return waveform
+
+ def cos_similarity(self, waveform, text):
+ # waveform: [bs, t_steps]
+ original_embed_mode = self.embed_mode
+ with torch.no_grad():
+ self.embed_mode = "audio"
+ audio_emb = self(waveform.cuda())
+ self.embed_mode = "text"
+ text_emb = self(text)
+ similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
+ self.embed_mode = original_embed_mode
+ return similarity.squeeze()
+
+ def build_unconditional_emb(self):
+ self.unconditional_token = self.model.get_text_embedding(
+ self.tokenizer(["", ""])
+ )[0:1]
+
+ def forward(self, batch):
+ # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
+ # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
+ if self.model.training == True and not self.training_mode:
+ print(
+ "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
+ )
+ self.model, self.model_cfg = create_model(
+ self.amodel,
+ self.tmodel,
+ self.pretrained,
+ precision=self.precision,
+ device="cuda",
+ enable_fusion=self.enable_fusion,
+ fusion_type=self.fusion_type,
+ )
+ for p in self.model.parameters():
+ p.requires_grad = False
+ self.model.eval()
+
+ if self.unconditional_token is None:
+ self.build_unconditional_emb()
+
+ # if(self.training_mode):
+ # assert self.model.training == True
+ # else:
+ # assert self.model.training == False
+
+ # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
+ if self.embed_mode == "audio":
+ if not self.training:
+ print("INFO: clap model calculate the audio embedding as condition")
+ with torch.no_grad():
+ # assert (
+ # self.sampling_rate == 16000
+ # ), "We only support 16000 sampling rate"
+
+ # if self.random_mute:
+ # batch = self._random_mute(batch)
+ # batch: [bs, 1, t-samples]
+ if self.sampling_rate != 48000:
+ batch = torchaudio.functional.resample(
+ batch, orig_freq=self.sampling_rate, new_freq=48000
+ )
+
+ audio_data = batch.squeeze(1)
+ mel = self.mel_transform(audio_data)
+ audio_dict = get_audio_features(
+ audio_data,
+ mel,
+ 480000,
+ data_truncating="fusion",
+ data_filling="repeatpad",
+ audio_cfg=self.model_cfg["audio_cfg"],
+ )
+ # [bs, 512]
+ embed = self.model.get_audio_embedding(audio_dict)
+ elif self.embed_mode == "text":
+ with torch.no_grad():
+ # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
+ text_data = self.tokenizer(batch)
+
+ if isinstance(batch, str) or (
+ isinstance(batch, list) and len(batch) == 1
+ ):
+ for key in text_data.keys():
+ text_data[key] = text_data[key].unsqueeze(0)
+
+ embed = self.model.get_text_embedding(text_data)
+
+ embed = embed.unsqueeze(1)
+ for i in range(embed.size(0)):
+ if self.make_decision(self.unconditional_prob):
+ embed[i] = self.unconditional_token
+ # embed = torch.randn((batch.size(0), 1, 512)).type_as(batch)
+ return embed.detach()
+
+ def tokenizer(self, text):
+ result = self.tokenize(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=512,
+ return_tensors="pt",
+ )
+ return {k: v.squeeze(0) for k, v in result.items()}
diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/__init__.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/attentions.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/attentions.py
new file mode 100755
index 0000000000000000000000000000000000000000..3553a688d41b07a45a7ced25f740a55dbc0b6d94
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/phoneme_encoder/attentions.py
@@ -0,0 +1,430 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+import audioldm2.latent_diffusion.modules.phoneme_encoder.commons as commons
+
+LRELU_SLOPE = 0.1
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ x = x.transpose(1, -1)
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+ return x.transpose(1, -1)
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size=1,
+ p_dropout=0.0,
+ window_size=4,
+ **kwargs
+ ):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+
+ self.drop = nn.Dropout(p_dropout)
+ self.attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels,
+ hidden_channels,
+ n_heads,
+ p_dropout=p_dropout,
+ window_size=window_size,
+ )
+ )
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(
+ hidden_channels,
+ hidden_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=p_dropout,
+ )
+ )
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask):
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ x = x * x_mask
+ for i in range(self.n_layers):
+ y = self.attn_layers[i](x, x, attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size=1,
+ p_dropout=0.0,
+ proximal_bias=False,
+ proximal_init=True,
+ **kwargs
+ ):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.proximal_bias = proximal_bias
+ self.proximal_init = proximal_init
+
+ self.drop = nn.Dropout(p_dropout)
+ self.self_attn_layers = nn.ModuleList()
+ self.norm_layers_0 = nn.ModuleList()
+ self.encdec_attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.self_attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels,
+ hidden_channels,
+ n_heads,
+ p_dropout=p_dropout,
+ proximal_bias=proximal_bias,
+ proximal_init=proximal_init,
+ )
+ )
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
+ self.encdec_attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+ )
+ )
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(
+ hidden_channels,
+ hidden_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=p_dropout,
+ causal=True,
+ )
+ )
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask, h, h_mask):
+ """
+ x: decoder input
+ h: encoder output
+ """
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+ device=x.device, dtype=x.dtype
+ )
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ x = x * x_mask
+ for i in range(self.n_layers):
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_0[i](x + y)
+
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(
+ self,
+ channels,
+ out_channels,
+ n_heads,
+ p_dropout=0.0,
+ window_size=None,
+ heads_share=True,
+ block_length=None,
+ proximal_bias=False,
+ proximal_init=False,
+ ):
+ super().__init__()
+ assert channels % n_heads == 0
+
+ self.channels = channels
+ self.out_channels = out_channels
+ self.n_heads = n_heads
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ self.heads_share = heads_share
+ self.block_length = block_length
+ self.proximal_bias = proximal_bias
+ self.proximal_init = proximal_init
+ self.attn = None
+
+ self.k_channels = channels // n_heads
+ self.conv_q = nn.Conv1d(channels, channels, 1)
+ self.conv_k = nn.Conv1d(channels, channels, 1)
+ self.conv_v = nn.Conv1d(channels, channels, 1)
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
+ self.drop = nn.Dropout(p_dropout)
+
+ if window_size is not None:
+ n_heads_rel = 1 if heads_share else n_heads
+ rel_stddev = self.k_channels**-0.5
+ self.emb_rel_k = nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+ self.emb_rel_v = nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+
+ nn.init.xavier_uniform_(self.conv_q.weight)
+ nn.init.xavier_uniform_(self.conv_k.weight)
+ nn.init.xavier_uniform_(self.conv_v.weight)
+ if proximal_init:
+ with torch.no_grad():
+ self.conv_k.weight.copy_(self.conv_q.weight)
+ self.conv_k.bias.copy_(self.conv_q.bias)
+
+ def forward(self, x, c, attn_mask=None):
+ q = self.conv_q(x)
+ k = self.conv_k(c)
+ v = self.conv_v(c)
+
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+ x = self.conv_o(x)
+ return x
+
+ def attention(self, query, key, value, mask=None):
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
+ b, d, t_s, t_t = (*key.size(), query.size(2))
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+ if self.window_size is not None:
+ assert (
+ t_s == t_t
+ ), "Relative attention is only available for self-attention."
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+ rel_logits = self._matmul_with_relative_keys(
+ query / math.sqrt(self.k_channels), key_relative_embeddings
+ )
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
+ scores = scores + scores_local
+ if self.proximal_bias:
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
+ scores = scores + self._attention_bias_proximal(t_s).to(
+ device=scores.device, dtype=scores.dtype
+ )
+ if mask is not None:
+ scores = scores.masked_fill(mask == 0, -1e4)
+ if self.block_length is not None:
+ assert (
+ t_s == t_t
+ ), "Local attention is only available for self-attention."
+ block_mask = (
+ torch.ones_like(scores)
+ .triu(-self.block_length)
+ .tril(self.block_length)
+ )
+ scores = scores.masked_fill(block_mask == 0, -1e4)
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
+ p_attn = self.drop(p_attn)
+ output = torch.matmul(p_attn, value)
+ if self.window_size is not None:
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
+ value_relative_embeddings = self._get_relative_embeddings(
+ self.emb_rel_v, t_s
+ )
+ output = output + self._matmul_with_relative_values(
+ relative_weights, value_relative_embeddings
+ )
+ output = (
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
+ return output, p_attn
+
+ def _matmul_with_relative_values(self, x, y):
+ """
+ x: [b, h, l, m]
+ y: [h or 1, m, d]
+ ret: [b, h, l, d]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0))
+ return ret
+
+ def _matmul_with_relative_keys(self, x, y):
+ """
+ x: [b, h, l, d]
+ y: [h or 1, m, d]
+ ret: [b, h, l, m]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+ return ret
+
+ def _get_relative_embeddings(self, relative_embeddings, length):
+ 2 * self.window_size + 1
+ # Pad first before slice to avoid using cond ops.
+ pad_length = max(length - (self.window_size + 1), 0)
+ slice_start_position = max((self.window_size + 1) - length, 0)
+ slice_end_position = slice_start_position + 2 * length - 1
+ if pad_length > 0:
+ padded_relative_embeddings = F.pad(
+ relative_embeddings,
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+ )
+ else:
+ padded_relative_embeddings = relative_embeddings
+ used_relative_embeddings = padded_relative_embeddings[
+ :, slice_start_position:slice_end_position
+ ]
+ return used_relative_embeddings
+
+ def _relative_position_to_absolute_position(self, x):
+ """
+ x: [b, h, l, 2*l-1]
+ ret: [b, h, l, l]
+ """
+ batch, heads, length, _ = x.size()
+ # Concat columns of pad to shift from relative to absolute indexing.
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
+ x_flat = x.view([batch, heads, length * 2 * length])
+ x_flat = F.pad(
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+ )
+
+ # Reshape and slice out the padded elements.
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+ :, :, :length, length - 1 :
+ ]
+ return x_final
+
+ def _absolute_position_to_relative_position(self, x):
+ """
+ x: [b, h, l, l]
+ ret: [b, h, l, 2*l-1]
+ """
+ batch, heads, length, _ = x.size()
+ # padd along column
+ x = F.pad(
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+ )
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+ # add 0's in the beginning that will skew the elements after reshape
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+ return x_final
+
+ def _attention_bias_proximal(self, length):
+ """Bias for self-attention to encourage attention to close positions.
+ Args:
+ length: an integer scalar.
+ Returns:
+ a Tensor with shape [1, 1, length, length]
+ """
+ r = torch.arange(length, dtype=torch.float32)
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=0.0,
+ activation=None,
+ causal=False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.activation = activation
+ self.causal = causal
+
+ if causal:
+ self.padding = self._causal_padding
+ else:
+ self.padding = self._same_padding
+
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+ self.drop = nn.Dropout(p_dropout)
+
+ def forward(self, x, x_mask):
+ x = self.conv_1(self.padding(x * x_mask))
+ if self.activation == "gelu":
+ x = x * torch.sigmoid(1.702 * x)
+ else:
+ x = torch.relu(x)
+ x = self.drop(x)
+ x = self.conv_2(self.padding(x * x_mask))
+ return x * x_mask
+
+ def _causal_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = self.kernel_size - 1
+ pad_r = 0
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, commons.convert_pad_shape(padding))
+ return x
+
+ def _same_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = (self.kernel_size - 1) // 2
+ pad_r = self.kernel_size // 2
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, commons.convert_pad_shape(padding))
+ return x
diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py
new file mode 100755
index 0000000000000000000000000000000000000000..9515724c12ab2f856b9a2ec14e38cc63df9b85d6
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py
@@ -0,0 +1,161 @@
+import math
+import torch
+from torch.nn import functional as F
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def intersperse(lst, item):
+ result = [item] * (len(lst) * 2 + 1)
+ result[1::2] = lst
+ return result
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+ """KL(P||Q)"""
+ kl = (logs_q - logs_p) - 0.5
+ kl += (
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+ )
+ return kl
+
+
+def rand_gumbel(shape):
+ """Sample from the Gumbel distribution, protect from overflows."""
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+ return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+ return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, :, idx_str:idx_end]
+ return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = x_lengths - segment_size + 1
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size)
+ return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+ position = torch.arange(length, dtype=torch.float)
+ num_timescales = channels // 2
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+ num_timescales - 1
+ )
+ inv_timescales = min_timescale * torch.exp(
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+ )
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
+ signal = signal.view(1, channels, length)
+ return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length):
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+ return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def shift_1d(x):
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+ return x
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def generate_path(duration, mask):
+ """
+ duration: [b, 1, t_x]
+ mask: [b, 1, t_y, t_x]
+ """
+ duration.device
+
+ b, _, t_y, t_x = mask.shape
+ cum_duration = torch.cumsum(duration, -1)
+
+ cum_duration_flat = cum_duration.view(b * t_x)
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+ path = path.view(b, t_x, t_y)
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+ path = path.unsqueeze(1).transpose(2, 3) * mask
+ return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
+ norm_type = float(norm_type)
+ if clip_value is not None:
+ clip_value = float(clip_value)
+
+ total_norm = 0
+ for p in parameters:
+ param_norm = p.grad.data.norm(norm_type)
+ total_norm += param_norm.item() ** norm_type
+ if clip_value is not None:
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
+ total_norm = total_norm ** (1.0 / norm_type)
+ return total_norm
diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..b39bf583b5ea88a4771181e491c8deb92b2d7559
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py
@@ -0,0 +1,50 @@
+import math
+import torch
+from torch import nn
+
+import audioldm2.latent_diffusion.modules.phoneme_encoder.commons as commons
+import audioldm2.latent_diffusion.modules.phoneme_encoder.attentions as attentions
+
+
+class TextEncoder(nn.Module):
+ def __init__(
+ self,
+ n_vocab,
+ out_channels=192,
+ hidden_channels=192,
+ filter_channels=768,
+ n_heads=2,
+ n_layers=6,
+ kernel_size=3,
+ p_dropout=0.1,
+ ):
+ super().__init__()
+ self.n_vocab = n_vocab
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
+
+ self.encoder = attentions.Encoder(
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+ )
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(self, x, x_lengths):
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
+ x = torch.transpose(x, 1, -1) # [b, h, t]
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
+ x.dtype
+ )
+
+ x = self.encoder(x * x_mask, x_mask)
+ stats = self.proj(x) * x_mask
+
+ m, logs = torch.split(stats, self.out_channels, dim=1)
+ return x, m, logs, x_mask
diff --git a/audioldm2/latent_diffusion/util.py b/audioldm2/latent_diffusion/util.py
new file mode 100755
index 0000000000000000000000000000000000000000..3dd301b1c0a39a5b905aa23f4b98d224df7d87d9
--- /dev/null
+++ b/audioldm2/latent_diffusion/util.py
@@ -0,0 +1,217 @@
+import importlib
+
+import torch
+import numpy as np
+from collections import abc
+
+import multiprocessing as mp
+from threading import Thread
+from queue import Queue
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
+ nc = int(40 * (wh[0] / 256))
+ lines = "\n".join(
+ xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
+ )
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def int16_to_float32(x):
+ return (x / 32767.0).astype(np.float32)
+
+
+def float32_to_int16(x):
+ x = np.clip(x, a_min=-1.0, a_max=1.0)
+ return (x * 32767.0).astype(np.int16)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == "__is_first_stage__":
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
+ # create dummy dataset instance
+
+ # run prefetching
+ if idx_to_fn:
+ res = func(data, worker_id=idx)
+ else:
+ res = func(data)
+ Q.put([idx, res])
+ Q.put("Done")
+
+
+def parallel_data_prefetch(
+ func: callable,
+ data,
+ n_proc,
+ target_data_type="ndarray",
+ cpu_intensive=True,
+ use_worker_id=False,
+):
+ # if target_data_type not in ["ndarray", "list"]:
+ # raise ValueError(
+ # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
+ # )
+ if isinstance(data, np.ndarray) and target_data_type == "list":
+ raise ValueError("list expected but function got ndarray.")
+ elif isinstance(data, abc.Iterable):
+ if isinstance(data, dict):
+ print(
+ f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
+ )
+ data = list(data.values())
+ if target_data_type == "ndarray":
+ data = np.asarray(data)
+ else:
+ data = list(data)
+ else:
+ raise TypeError(
+ f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
+ )
+
+ if cpu_intensive:
+ Q = mp.Queue(1000)
+ proc = mp.Process
+ else:
+ Q = Queue(1000)
+ proc = Thread
+ # spawn processes
+ if target_data_type == "ndarray":
+ arguments = [
+ [func, Q, part, i, use_worker_id]
+ for i, part in enumerate(np.array_split(data, n_proc))
+ ]
+ else:
+ step = (
+ int(len(data) / n_proc + 1)
+ if len(data) % n_proc != 0
+ else int(len(data) / n_proc)
+ )
+ arguments = [
+ [func, Q, part, i, use_worker_id]
+ for i, part in enumerate(
+ [data[i : i + step] for i in range(0, len(data), step)]
+ )
+ ]
+ processes = []
+ for i in range(n_proc):
+ p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
+ processes += [p]
+
+ # start processes
+ print(f"Start prefetching...")
+ import time
+
+ start = time.time()
+ gather_res = [[] for _ in range(n_proc)]
+ try:
+ for p in processes:
+ p.start()
+
+ k = 0
+ while k < n_proc:
+ # get result
+ res = Q.get()
+ if res == "Done":
+ k += 1
+ else:
+ gather_res[res[0]] = res[1]
+
+ except Exception as e:
+ print("Exception: ", e)
+ for p in processes:
+ p.terminate()
+
+ raise e
+ finally:
+ for p in processes:
+ p.join()
+ print(f"Prefetching complete. [{time.time() - start} sec.]")
+
+ if target_data_type == "ndarray":
+ if not isinstance(gather_res[0], np.ndarray):
+ return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
+
+ # order outputs
+ return np.concatenate(gather_res, axis=0)
+ elif target_data_type == "list":
+ out = []
+ for r in gather_res:
+ out.extend(r)
+ return out
+ else:
+ return gather_res
diff --git a/audioldm2/latent_encoder/__init__.py b/audioldm2/latent_encoder/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_encoder/autoencoder.py b/audioldm2/latent_encoder/autoencoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..f07075bb76a34edd8568797961752e4957129f92
--- /dev/null
+++ b/audioldm2/latent_encoder/autoencoder.py
@@ -0,0 +1,326 @@
+import torch
+import os
+
+import torch.nn.functional as F
+import numpy as np
+from audioldm2.latent_diffusion.modules.ema import *
+
+from audioldm2.latent_diffusion.modules.diffusionmodules.model import Encoder, Decoder
+from audioldm2.latent_diffusion.modules.distributions.distributions import (
+ DiagonalGaussianDistribution,
+)
+import soundfile as sf
+
+from audioldm2.utilities.model import get_vocoder
+from audioldm2.utilities.tools import synth_one_sample
+
+
+class AutoencoderKL(nn.Module):
+ def __init__(
+ self,
+ ddconfig=None,
+ lossconfig=None,
+ batchsize=None,
+ embed_dim=None,
+ time_shuffle=1,
+ subband=1,
+ sampling_rate=16000,
+ ckpt_path=None,
+ reload_from_ckpt=None,
+ ignore_keys=[],
+ image_key="fbank",
+ colorize_nlabels=None,
+ monitor=None,
+ base_learning_rate=1e-5,
+ ):
+ super().__init__()
+ self.automatic_optimization = False
+ assert (
+ "mel_bins" in ddconfig.keys()
+ ), "mel_bins is not specified in the Autoencoder config"
+ num_mel = ddconfig["mel_bins"]
+ self.image_key = image_key
+ self.sampling_rate = sampling_rate
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+
+ self.loss = None
+ self.subband = int(subband)
+
+ if self.subband > 1:
+ print("Use subband decomposition %s" % self.subband)
+
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+
+ if self.image_key == "fbank":
+ self.vocoder = get_vocoder(None, "cpu", num_mel)
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels) == int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.learning_rate = float(base_learning_rate)
+ # print("Initial learning rate %s" % self.learning_rate)
+
+ self.time_shuffle = time_shuffle
+ self.reload_from_ckpt = reload_from_ckpt
+ self.reloaded = False
+ self.mean, self.std = None, None
+
+ self.feature_cache = None
+ self.flag_first_run = True
+ self.train_step = 0
+
+ self.logger_save_dir = None
+ self.logger_exp_name = None
+
+ def get_log_dir(self):
+ if self.logger_save_dir is None and self.logger_exp_name is None:
+ return os.path.join(self.logger.save_dir, self.logger._project)
+ else:
+ return os.path.join(self.logger_save_dir, self.logger_exp_name)
+
+ def set_log_dir(self, save_dir, exp_name):
+ self.logger_save_dir = save_dir
+ self.logger_exp_name = exp_name
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def encode(self, x):
+ # x = self.time_shuffle_operation(x)
+ # x = self.freq_split_subband(x)
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ # bs, ch, shuffled_timesteps, fbins = dec.size()
+ # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins)
+ # dec = self.freq_merge_subband(dec)
+ return dec
+
+ def decode_to_waveform(self, dec):
+ from audioldm2.utilities.model import vocoder_infer
+
+ if self.image_key == "fbank":
+ dec = dec.squeeze(1).permute(0, 2, 1)
+ wav_reconstruction = vocoder_infer(dec, self.vocoder)
+ elif self.image_key == "stft":
+ dec = dec.squeeze(1).permute(0, 2, 1)
+ wav_reconstruction = self.wave_decoder(dec)
+ return wav_reconstruction
+
+ def visualize_latent(self, input):
+ import matplotlib.pyplot as plt
+
+ # for i in range(10):
+ # zero_input = torch.zeros_like(input) - 11.59
+ # zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59
+
+ # posterior = self.encode(zero_input)
+ # latent = posterior.sample()
+ # avg_latent = torch.mean(latent, dim=1)[0]
+ # plt.imshow(avg_latent.cpu().detach().numpy().T)
+ # plt.savefig("%s.png" % i)
+ # plt.close()
+
+ np.save("input.npy", input.cpu().detach().numpy())
+ # zero_input = torch.zeros_like(input) - 11.59
+ time_input = input.clone()
+ time_input[:, :, :, :32] *= 0
+ time_input[:, :, :, :32] -= 11.59
+
+ np.save("time_input.npy", time_input.cpu().detach().numpy())
+
+ posterior = self.encode(time_input)
+ latent = posterior.sample()
+ np.save("time_latent.npy", latent.cpu().detach().numpy())
+ avg_latent = torch.mean(latent, dim=1)
+ for i in range(avg_latent.size(0)):
+ plt.imshow(avg_latent[i].cpu().detach().numpy().T)
+ plt.savefig("freq_%s.png" % i)
+ plt.close()
+
+ freq_input = input.clone()
+ freq_input[:, :, :512, :] *= 0
+ freq_input[:, :, :512, :] -= 11.59
+
+ np.save("freq_input.npy", freq_input.cpu().detach().numpy())
+
+ posterior = self.encode(freq_input)
+ latent = posterior.sample()
+ np.save("freq_latent.npy", latent.cpu().detach().numpy())
+ avg_latent = torch.mean(latent, dim=1)
+ for i in range(avg_latent.size(0)):
+ plt.imshow(avg_latent[i].cpu().detach().numpy().T)
+ plt.savefig("time_%s.png" % i)
+ plt.close()
+
+ def get_input(self, batch):
+ fname, text, label_indices, waveform, stft, fbank = (
+ batch["fname"],
+ batch["text"],
+ batch["label_vector"],
+ batch["waveform"],
+ batch["stft"],
+ batch["log_mel_spec"],
+ )
+ # if(self.time_shuffle != 1):
+ # if(fbank.size(1) % self.time_shuffle != 0):
+ # pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle)
+ # fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len))
+
+ ret = {}
+
+ ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
+ fbank.unsqueeze(1),
+ stft.unsqueeze(1),
+ fname,
+ waveform.unsqueeze(1),
+ )
+
+ return ret
+
+ def save_wave(self, batch_wav, fname, save_dir):
+ os.makedirs(save_dir, exist_ok=True)
+
+ for wav, name in zip(batch_wav, fname):
+ name = os.path.basename(name)
+
+ sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
+ log = dict()
+ x = batch.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ log["samples"] = self.decode(posterior.sample())
+ log["reconstructions"] = xrec
+
+ log["inputs"] = x
+ wavs = self._log_img(log, train=train, index=0, waveform=waveform)
+ return wavs
+
+ def _log_img(self, log, train=True, index=0, waveform=None):
+ images_input = self.tensor2numpy(log["inputs"][index, 0]).T
+ images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
+ images_samples = self.tensor2numpy(log["samples"][index, 0]).T
+
+ if train:
+ name = "train"
+ else:
+ name = "val"
+
+ if self.logger is not None:
+ self.logger.log_image(
+ "img_%s" % name,
+ [images_input, images_reconstruct, images_samples],
+ caption=["input", "reconstruct", "samples"],
+ )
+
+ inputs, reconstructions, samples = (
+ log["inputs"],
+ log["reconstructions"],
+ log["samples"],
+ )
+
+ if self.image_key == "fbank":
+ wav_original, wav_prediction = synth_one_sample(
+ inputs[index],
+ reconstructions[index],
+ labels="validation",
+ vocoder=self.vocoder,
+ )
+ wav_original, wav_samples = synth_one_sample(
+ inputs[index], samples[index], labels="validation", vocoder=self.vocoder
+ )
+ wav_original, wav_samples, wav_prediction = (
+ wav_original[0],
+ wav_samples[0],
+ wav_prediction[0],
+ )
+ elif self.image_key == "stft":
+ wav_prediction = (
+ self.decode_to_waveform(reconstructions)[index, 0]
+ .cpu()
+ .detach()
+ .numpy()
+ )
+ wav_samples = (
+ self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
+ )
+ wav_original = waveform[index, 0].cpu().detach().numpy()
+
+ if self.logger is not None:
+ self.logger.experiment.log(
+ {
+ "original_%s"
+ % name: wandb.Audio(
+ wav_original, caption="original", sample_rate=self.sampling_rate
+ ),
+ "reconstruct_%s"
+ % name: wandb.Audio(
+ wav_prediction,
+ caption="reconstruct",
+ sample_rate=self.sampling_rate,
+ ),
+ "samples_%s"
+ % name: wandb.Audio(
+ wav_samples, caption="samples", sample_rate=self.sampling_rate
+ ),
+ }
+ )
+
+ return wav_original, wav_prediction, wav_samples
+
+ def tensor2numpy(self, tensor):
+ return tensor.cpu().detach().numpy()
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
+ return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+ def __init__(self, *args, vq_interface=False, **kwargs):
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
+ super().__init__()
+
+ def encode(self, x, *args, **kwargs):
+ return x
+
+ def decode(self, x, *args, **kwargs):
+ return x
+
+ def quantize(self, x, *args, **kwargs):
+ if self.vq_interface:
+ return x, None, [None, None, None]
+ return x
+
+ def forward(self, x, *args, **kwargs):
+ return x
diff --git a/audioldm2/pipeline.py b/audioldm2/pipeline.py
new file mode 100755
index 0000000000000000000000000000000000000000..1eec55b0198049f8baf263c3b80a7a8a0584ebeb
--- /dev/null
+++ b/audioldm2/pipeline.py
@@ -0,0 +1,201 @@
+import os
+
+import yaml
+import torch
+import torchaudio
+
+from audioldm2.latent_diffusion.models.ddpm import LatentDiffusion
+from audioldm2.utils import default_audioldm_config, get_metadata, download_checkpoint
+from audioldm2.utilities.audio import read_wav_file
+import os
+
+CACHE_DIR = os.getenv(
+ "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2")
+)
+
+
+def seed_everything(seed):
+ import random, os
+ import numpy as np
+ import torch
+
+ random.seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = True
+
+
+def text_to_filename(text):
+ return text.replace(" ", "_").replace("'", "_").replace('"', "_")
+
+
+def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
+ norm_mean = -4.2677393
+ norm_std = 4.5689974
+
+ if sampling_rate != 16000:
+ waveform_16k = torchaudio.functional.resample(
+ waveform, orig_freq=sampling_rate, new_freq=16000
+ )
+ else:
+ waveform_16k = waveform
+
+ waveform_16k = waveform_16k - waveform_16k.mean()
+ fbank = torchaudio.compliance.kaldi.fbank(
+ waveform_16k,
+ htk_compat=True,
+ sample_frequency=16000,
+ use_energy=False,
+ window_type="hanning",
+ num_mel_bins=128,
+ dither=0.0,
+ frame_shift=10,
+ )
+
+ TARGET_LEN = log_mel_spec.size(0)
+
+ # cut and pad
+ n_frames = fbank.shape[0]
+ p = TARGET_LEN - n_frames
+ if p > 0:
+ m = torch.nn.ZeroPad2d((0, 0, 0, p))
+ fbank = m(fbank)
+ elif p < 0:
+ fbank = fbank[:TARGET_LEN, :]
+
+ fbank = (fbank - norm_mean) / (norm_std * 2)
+
+ return {"ta_kaldi_fbank": fbank} # [1024, 128]
+
+
+def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1):
+ text = [text] * batchsize
+ if batchsize < 1:
+ print("Warning: Batchsize must be at least 1. Batchsize is set to .")
+
+ if fbank is None:
+ fbank = torch.zeros(
+ (batchsize, 1024, 64)
+ ) # Not used, here to keep the code format
+ else:
+ fbank = torch.FloatTensor(fbank)
+ fbank = fbank.expand(batchsize, 1024, 64)
+ assert fbank.size(0) == batchsize
+
+ stft = torch.zeros((batchsize, 1024, 512)) # Not used
+
+ if waveform is None:
+ waveform = torch.zeros((batchsize, 160000)) # Not used
+ ta_kaldi_fbank = torch.zeros((batchsize, 1024, 128))
+ else:
+ waveform = torch.FloatTensor(waveform)
+ waveform = waveform.expand(batchsize, -1)
+ assert waveform.size(0) == batchsize
+ ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, 16000, fbank)
+
+ batch = {
+ "text": text, # list
+ "fname": [text_to_filename(t) for t in text], # list
+ "waveform": waveform,
+ "stft": stft,
+ "log_mel_spec": fbank,
+ "ta_kaldi_fbank": ta_kaldi_fbank,
+ }
+
+ return batch
+
+
+def round_up_duration(duration):
+ return int(round(duration / 2.5) + 1) * 2.5
+
+
+def split_clap_weight_to_pth(checkpoint):
+ if os.path.exists(os.path.join(CACHE_DIR, "clap.pth")):
+ return
+ print("Constructing the weight for the CLAP model.")
+ include_keys = "cond_stage_models.0.cond_stage_models.0.model."
+ new_state_dict = {}
+ for each in checkpoint["state_dict"].keys():
+ if include_keys in each:
+ new_state_dict[each.replace(include_keys, "module.")] = checkpoint[
+ "state_dict"
+ ][each]
+ torch.save({"state_dict": new_state_dict}, os.path.join(CACHE_DIR, "clap.pth"))
+
+
+def build_model(ckpt_path=None, config=None, model_name="audioldm2-full"):
+ print("Loading AudioLDM-2: %s" % model_name)
+
+ if ckpt_path is None:
+ ckpt_path = get_metadata()[model_name]["path"]
+
+ if not os.path.exists(ckpt_path):
+ download_checkpoint(model_name)
+
+ if torch.cuda.is_available():
+ device = torch.device("cuda:0")
+ else:
+ device = torch.device("cpu")
+
+ if config is not None:
+ assert type(config) is str
+ config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
+ else:
+ config = default_audioldm_config(model_name)
+
+ # # Use text as condition instead of using waveform during training
+ config["model"]["params"]["device"] = device
+ # config["model"]["params"]["cond_stage_key"] = "text"
+
+ # No normalization here
+ latent_diffusion = LatentDiffusion(**config["model"]["params"])
+
+ resume_from_checkpoint = ckpt_path
+
+ checkpoint = torch.load(resume_from_checkpoint, map_location=device)
+
+ latent_diffusion.load_state_dict(checkpoint["state_dict"])
+
+ latent_diffusion.eval()
+ latent_diffusion = latent_diffusion.to(device)
+
+ return latent_diffusion
+
+def duration_to_latent_t_size(duration):
+ return int(duration * 25.6)
+
+def text_to_audio(
+ latent_diffusion,
+ text,
+ seed=42,
+ ddim_steps=200,
+ duration=10,
+ batchsize=1,
+ guidance_scale=3.5,
+ n_candidate_gen_per_text=3,
+ config=None,
+):
+ assert (
+ duration == 10
+ ), "Error: Currently we only support 10 seconds of generation. Generating longer files requires some extra coding, which would be a part of the future work."
+
+ seed_everything(int(seed))
+ waveform = None
+
+ batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize)
+
+ latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
+
+ with torch.no_grad():
+ waveform = latent_diffusion.generate_batch(
+ batch,
+ unconditional_guidance_scale=guidance_scale,
+ ddim_steps=ddim_steps,
+ n_gen=n_candidate_gen_per_text,
+ duration=duration,
+ )
+
+ return waveform
diff --git a/audioldm2/utilities/__init__.py b/audioldm2/utilities/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..495e8fe675337df0afacd3a31d06d0241b6b0e63
--- /dev/null
+++ b/audioldm2/utilities/__init__.py
@@ -0,0 +1,3 @@
+from .tools import *
+from .data import *
+from .model import *
diff --git a/audioldm2/utilities/audio/__init__.py b/audioldm2/utilities/audio/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..c39f9243d2d7b4fc5dea18f56b153b0f5c5bbd4c
--- /dev/null
+++ b/audioldm2/utilities/audio/__init__.py
@@ -0,0 +1,3 @@
+from .audio_processing import *
+from .stft import *
+from .tools import *
diff --git a/audioldm2/utilities/audio/audio_processing.py b/audioldm2/utilities/audio/audio_processing.py
new file mode 100755
index 0000000000000000000000000000000000000000..77a4057aa82f226f68474f4c2a19eba84510d663
--- /dev/null
+++ b/audioldm2/utilities/audio/audio_processing.py
@@ -0,0 +1,100 @@
+import torch
+import numpy as np
+import librosa.util as librosa_util
+from scipy.signal import get_window
+
+
+def window_sumsquare(
+ window,
+ n_frames,
+ hop_length,
+ win_length,
+ n_fft,
+ dtype=np.float32,
+ norm=None,
+):
+ """
+ # from librosa 0.6
+ Compute the sum-square envelope of a window function at a given hop length.
+
+ This is used to estimate modulation effects induced by windowing
+ observations in short-time fourier transforms.
+
+ Parameters
+ ----------
+ window : string, tuple, number, callable, or list-like
+ Window specification, as in `get_window`
+
+ n_frames : int > 0
+ The number of analysis frames
+
+ hop_length : int > 0
+ The number of samples to advance between frames
+
+ win_length : [optional]
+ The length of the window function. By default, this matches `n_fft`.
+
+ n_fft : int > 0
+ The length of each analysis frame.
+
+ dtype : np.dtype
+ The data type of the output
+
+ Returns
+ -------
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+ The sum-squared envelope of the window function
+ """
+ if win_length is None:
+ win_length = n_fft
+
+ n = n_fft + hop_length * (n_frames - 1)
+ x = np.zeros(n, dtype=dtype)
+
+ # Compute the squared window at the desired length
+ win_sq = get_window(window, win_length, fftbins=True)
+ win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
+
+ # Fill the envelope
+ for i in range(n_frames):
+ sample = i * hop_length
+ x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
+ return x
+
+
+def griffin_lim(magnitudes, stft_fn, n_iters=30):
+ """
+ PARAMS
+ ------
+ magnitudes: spectrogram magnitudes
+ stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
+ """
+
+ angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
+ angles = angles.astype(np.float32)
+ angles = torch.autograd.Variable(torch.from_numpy(angles))
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+
+ for i in range(n_iters):
+ _, angles = stft_fn.transform(signal)
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+ return signal
+
+
+def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
+ """
+ PARAMS
+ ------
+ C: compression factor
+ """
+ return normalize_fun(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ """
+ PARAMS
+ ------
+ C: compression factor used to compress
+ """
+ return torch.exp(x) / C
diff --git a/audioldm2/utilities/audio/stft.py b/audioldm2/utilities/audio/stft.py
new file mode 100755
index 0000000000000000000000000000000000000000..508f33674e6dd8a5557205c8e77e07955df13a87
--- /dev/null
+++ b/audioldm2/utilities/audio/stft.py
@@ -0,0 +1,178 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy.signal import get_window
+from librosa.util import pad_center, tiny
+from librosa.filters import mel as librosa_mel_fn
+
+from audioldm2.utilities.audio.audio_processing import (
+ dynamic_range_compression,
+ dynamic_range_decompression,
+ window_sumsquare,
+)
+
+
+class STFT(torch.nn.Module):
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+
+ def __init__(self, filter_length, hop_length, win_length, window="hann"):
+ super(STFT, self).__init__()
+ self.filter_length = filter_length
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.window = window
+ self.forward_transform = None
+ scale = self.filter_length / self.hop_length
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+ cutoff = int((self.filter_length / 2 + 1))
+ fourier_basis = np.vstack(
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
+ )
+
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+ inverse_basis = torch.FloatTensor(
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :]
+ )
+
+ if window is not None:
+ assert filter_length >= win_length
+ # get window and zero center pad it to filter_length
+ fft_window = get_window(window, win_length, fftbins=True)
+ fft_window = pad_center(fft_window, filter_length)
+ fft_window = torch.from_numpy(fft_window).float()
+
+ # window the bases
+ forward_basis *= fft_window
+ inverse_basis *= fft_window
+
+ self.register_buffer("forward_basis", forward_basis.float())
+ self.register_buffer("inverse_basis", inverse_basis.float())
+
+ def transform(self, input_data):
+ num_batches = input_data.size(0)
+ num_samples = input_data.size(1)
+
+ self.num_samples = num_samples
+
+ # similar to librosa, reflect-pad the input
+ input_data = input_data.view(num_batches, 1, num_samples)
+ input_data = F.pad(
+ input_data.unsqueeze(1),
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+ mode="reflect",
+ )
+ input_data = input_data.squeeze(1)
+
+ forward_transform = F.conv1d(
+ input_data,
+ torch.autograd.Variable(self.forward_basis, requires_grad=False),
+ stride=self.hop_length,
+ padding=0,
+ ).cpu()
+
+ cutoff = int((self.filter_length / 2) + 1)
+ real_part = forward_transform[:, :cutoff, :]
+ imag_part = forward_transform[:, cutoff:, :]
+
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
+ phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
+
+ return magnitude, phase
+
+ def inverse(self, magnitude, phase):
+ recombine_magnitude_phase = torch.cat(
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
+ )
+
+ inverse_transform = F.conv_transpose1d(
+ recombine_magnitude_phase,
+ torch.autograd.Variable(self.inverse_basis, requires_grad=False),
+ stride=self.hop_length,
+ padding=0,
+ )
+
+ if self.window is not None:
+ window_sum = window_sumsquare(
+ self.window,
+ magnitude.size(-1),
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ n_fft=self.filter_length,
+ dtype=np.float32,
+ )
+ # remove modulation effects
+ approx_nonzero_indices = torch.from_numpy(
+ np.where(window_sum > tiny(window_sum))[0]
+ )
+ window_sum = torch.autograd.Variable(
+ torch.from_numpy(window_sum), requires_grad=False
+ )
+ window_sum = window_sum
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
+ approx_nonzero_indices
+ ]
+
+ # scale by hop ratio
+ inverse_transform *= float(self.filter_length) / self.hop_length
+
+ inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
+ inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
+
+ return inverse_transform
+
+ def forward(self, input_data):
+ self.magnitude, self.phase = self.transform(input_data)
+ reconstruction = self.inverse(self.magnitude, self.phase)
+ return reconstruction
+
+
+class TacotronSTFT(torch.nn.Module):
+ def __init__(
+ self,
+ filter_length,
+ hop_length,
+ win_length,
+ n_mel_channels,
+ sampling_rate,
+ mel_fmin,
+ mel_fmax,
+ ):
+ super(TacotronSTFT, self).__init__()
+ self.n_mel_channels = n_mel_channels
+ self.sampling_rate = sampling_rate
+ self.stft_fn = STFT(filter_length, hop_length, win_length)
+ mel_basis = librosa_mel_fn(
+ sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
+ )
+ mel_basis = torch.from_numpy(mel_basis).float()
+ self.register_buffer("mel_basis", mel_basis)
+
+ def spectral_normalize(self, magnitudes, normalize_fun):
+ output = dynamic_range_compression(magnitudes, normalize_fun)
+ return output
+
+ def spectral_de_normalize(self, magnitudes):
+ output = dynamic_range_decompression(magnitudes)
+ return output
+
+ def mel_spectrogram(self, y, normalize_fun=torch.log):
+ """Computes mel-spectrograms from a batch of waves
+ PARAMS
+ ------
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
+
+ RETURNS
+ -------
+ mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
+ """
+ assert torch.min(y.data) >= -1, torch.min(y.data)
+ assert torch.max(y.data) <= 1, torch.max(y.data)
+
+ magnitudes, phases = self.stft_fn.transform(y)
+ magnitudes = magnitudes.data
+ mel_output = torch.matmul(self.mel_basis, magnitudes)
+ mel_output = self.spectral_normalize(mel_output, normalize_fun)
+ energy = torch.norm(magnitudes, dim=1)
+
+ return mel_output, magnitudes, phases, energy
diff --git a/audioldm2/utilities/audio/tools.py b/audioldm2/utilities/audio/tools.py
new file mode 100755
index 0000000000000000000000000000000000000000..8c666a7c67e0ae93edbad666520fd2e98fd29d18
--- /dev/null
+++ b/audioldm2/utilities/audio/tools.py
@@ -0,0 +1,69 @@
+import torch
+import numpy as np
+from scipy.io.wavfile import write
+import torchaudio
+
+from audioldm2.utilities.audio.audio_processing import griffin_lim
+
+
+def pad_wav(waveform, segment_length):
+ waveform_length = waveform.shape[-1]
+ assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
+ if segment_length is None or waveform_length == segment_length:
+ return waveform
+ elif waveform_length > segment_length:
+ return waveform[:segment_length]
+ elif waveform_length < segment_length:
+ temp_wav = np.zeros((1, segment_length))
+ temp_wav[:, :waveform_length] = waveform
+ return temp_wav
+
+
+def normalize_wav(waveform):
+ waveform = waveform - np.mean(waveform)
+ waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
+ return waveform * 0.5
+
+
+def read_wav_file(filename, segment_length):
+ # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
+ waveform, sr = torchaudio.load(filename) # Faster!!!
+ waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
+ waveform = waveform.numpy()[0, ...]
+ waveform = normalize_wav(waveform)
+ waveform = waveform[None, ...]
+ waveform = pad_wav(waveform, segment_length)
+
+ waveform = waveform / np.max(np.abs(waveform))
+ waveform = 0.5 * waveform
+
+ return waveform
+
+
+def get_mel_from_wav(audio, _stft):
+ audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
+ audio = torch.autograd.Variable(audio, requires_grad=False)
+ melspec, magnitudes, phases, energy = _stft.mel_spectrogram(audio)
+ melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
+ magnitudes = torch.squeeze(magnitudes, 0).numpy().astype(np.float32)
+ energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
+ return melspec, magnitudes, energy
+
+
+def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
+ mel = torch.stack([mel])
+ mel_decompress = _stft.spectral_de_normalize(mel)
+ mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
+ spec_from_mel_scaling = 1000
+ spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
+ spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
+ spec_from_mel = spec_from_mel * spec_from_mel_scaling
+
+ audio = griffin_lim(
+ torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters
+ )
+
+ audio = audio.squeeze()
+ audio = audio.cpu().numpy()
+ audio_path = out_filename
+ write(audio_path, _stft.sampling_rate, audio)
diff --git a/audioldm2/utilities/data/__init__.py b/audioldm2/utilities/data/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..13a9804e72b88e3b9078940aee87db73788c1fb5
--- /dev/null
+++ b/audioldm2/utilities/data/__init__.py
@@ -0,0 +1 @@
+from .dataset import Dataset
diff --git a/audioldm2/utilities/data/add_on.py b/audioldm2/utilities/data/add_on.py
new file mode 100755
index 0000000000000000000000000000000000000000..4cfc6297e2f66759077c1540fc04b19560f3659c
--- /dev/null
+++ b/audioldm2/utilities/data/add_on.py
@@ -0,0 +1,508 @@
+import os
+import torch
+import numpy as np
+import torchaudio
+import matplotlib.pyplot as plt
+
+CACHE = {
+ "get_vits_phoneme_ids": {
+ "PAD_LENGTH": 310,
+ "_pad": "_",
+ "_punctuation": ';:,.!?¡¿—…"«»“” ',
+ "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+ "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
+ "_special": "♪☎☒☝⚠",
+ }
+}
+
+CACHE["get_vits_phoneme_ids"]["symbols"] = (
+ [CACHE["get_vits_phoneme_ids"]["_pad"]]
+ + list(CACHE["get_vits_phoneme_ids"]["_punctuation"])
+ + list(CACHE["get_vits_phoneme_ids"]["_letters"])
+ + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"])
+ + list(CACHE["get_vits_phoneme_ids"]["_special"])
+)
+CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = {
+ s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"])
+}
+
+
+def get_vits_phoneme_ids(config, dl_output, metadata):
+ pad_token_id = 0
+ pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
+ _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]
+
+ assert (
+ "phonemes" in metadata.keys()
+ ), "You must provide vits phonemes on using addon get_vits_phoneme_ids"
+ clean_text = metadata["phonemes"]
+ sequence = []
+
+ for symbol in clean_text:
+ symbol_id = _symbol_to_id[symbol]
+ sequence += [symbol_id]
+
+ inserted_zero_sequence = [0] * (len(sequence) * 2)
+ inserted_zero_sequence[1::2] = sequence
+ inserted_zero_sequence = inserted_zero_sequence + [0]
+
+ def _pad_phonemes(phonemes_list):
+ return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))
+
+ return {"phoneme_idx": torch.LongTensor(_pad_phonemes(inserted_zero_sequence))}
+
+
+def get_vits_phoneme_ids_no_padding(config, dl_output, metadata):
+ pad_token_id = 0
+ pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
+ _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]
+
+ assert (
+ "phonemes" in metadata.keys()
+ ), "You must provide vits phonemes on using addon get_vits_phoneme_ids"
+ clean_text = metadata["phonemes"] + "⚠"
+ sequence = []
+
+ for symbol in clean_text:
+ if symbol not in _symbol_to_id.keys():
+ print("%s is not in the vocabulary. %s" % (symbol, clean_text))
+ symbol = "_"
+ symbol_id = _symbol_to_id[symbol]
+ sequence += [symbol_id]
+
+ def _pad_phonemes(phonemes_list):
+ return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))
+
+ sequence = sequence[:pad_length]
+
+ return {"phoneme_idx": torch.LongTensor(_pad_phonemes(sequence))}
+
+
+def calculate_relative_bandwidth(config, dl_output, metadata):
+ assert "stft" in dl_output.keys()
+
+ # The last dimension of the stft feature is the frequency dimension
+ freq_dimensions = dl_output["stft"].size(-1)
+
+ freq_energy_dist = torch.sum(dl_output["stft"], dim=0)
+ freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
+ total_energy = freq_energy_dist[-1]
+
+ percentile_5th = total_energy * 0.05
+ percentile_95th = total_energy * 0.95
+
+ lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
+ higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))
+
+ lower_idx = int((lower_idx / freq_dimensions) * 1000)
+ higher_idx = int((higher_idx / freq_dimensions) * 1000)
+
+ return {"freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx])}
+
+
+def calculate_mel_spec_relative_bandwidth_as_extra_channel(config, dl_output, metadata):
+ assert "stft" in dl_output.keys()
+ linear_mel_spec = torch.exp(torch.clip(dl_output["log_mel_spec"], max=10))
+
+ # The last dimension of the stft feature is the frequency dimension
+ freq_dimensions = linear_mel_spec.size(-1)
+ freq_energy_dist = torch.sum(linear_mel_spec, dim=0)
+ freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
+ total_energy = freq_energy_dist[-1]
+
+ percentile_5th = total_energy * 0.05
+ percentile_95th = total_energy * 0.95
+
+ lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
+ higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))
+
+ latent_t_size = config["model"]["params"]["latent_t_size"]
+ latent_f_size = config["model"]["params"]["latent_f_size"]
+
+ lower_idx = int(latent_f_size * float((lower_idx / freq_dimensions)))
+ higher_idx = int(latent_f_size * float((higher_idx / freq_dimensions)))
+
+ bandwidth_condition = torch.zeros((latent_t_size, latent_f_size))
+ bandwidth_condition[:, lower_idx:higher_idx] += 1.0
+
+ return {
+ "mel_spec_bandwidth_cond_extra_channel": bandwidth_condition,
+ "freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx]),
+ }
+
+
+def waveform_rs_48k(config, dl_output, metadata):
+ waveform = dl_output["waveform"] # [1, samples]
+ sampling_rate = dl_output["sampling_rate"]
+
+ if sampling_rate != 48000:
+ waveform_48k = torchaudio.functional.resample(
+ waveform, orig_freq=sampling_rate, new_freq=48000
+ )
+ else:
+ waveform_48k = waveform
+
+ return {"waveform_48k": waveform_48k}
+
+
+def extract_vits_phoneme_and_flant5_text(config, dl_output, metadata):
+ assert (
+ "phoneme" not in metadata.keys()
+ ), "The metadata of speech you use seems belong to fastspeech. Please check dataset_root.json"
+
+ if "phonemes" in metadata.keys():
+ new_item = get_vits_phoneme_ids_no_padding(config, dl_output, metadata)
+ new_item["text"] = "" # We assume TTS data does not have text description
+ else:
+ fake_metadata = {"phonemes": ""} # Add empty phoneme sequence
+ new_item = get_vits_phoneme_ids_no_padding(config, dl_output, fake_metadata)
+
+ return new_item
+
+
+def extract_fs2_phoneme_and_flant5_text(config, dl_output, metadata):
+ if "phoneme" in metadata.keys():
+ new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata)
+ new_item["text"] = ""
+ else:
+ fake_metadata = {"phoneme": []}
+ new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, fake_metadata)
+ return new_item
+
+
+def extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata):
+ PAD_LENGTH = 135
+
+ phonemes_lookup_dict = {
+ "K": 0,
+ "IH2": 1,
+ "NG": 2,
+ "OW2": 3,
+ "AH2": 4,
+ "F": 5,
+ "AE0": 6,
+ "IY0": 7,
+ "SH": 8,
+ "G": 9,
+ "W": 10,
+ "UW1": 11,
+ "AO2": 12,
+ "AW2": 13,
+ "UW0": 14,
+ "EY2": 15,
+ "UW2": 16,
+ "AE2": 17,
+ "IH0": 18,
+ "P": 19,
+ "D": 20,
+ "ER1": 21,
+ "AA1": 22,
+ "EH0": 23,
+ "UH1": 24,
+ "N": 25,
+ "V": 26,
+ "AY1": 27,
+ "EY1": 28,
+ "UH2": 29,
+ "EH1": 30,
+ "L": 31,
+ "AA2": 32,
+ "R": 33,
+ "OY1": 34,
+ "Y": 35,
+ "ER2": 36,
+ "S": 37,
+ "AE1": 38,
+ "AH1": 39,
+ "JH": 40,
+ "ER0": 41,
+ "EH2": 42,
+ "IY2": 43,
+ "OY2": 44,
+ "AW1": 45,
+ "IH1": 46,
+ "IY1": 47,
+ "OW0": 48,
+ "AO0": 49,
+ "AY0": 50,
+ "EY0": 51,
+ "AY2": 52,
+ "UH0": 53,
+ "M": 54,
+ "TH": 55,
+ "T": 56,
+ "OY0": 57,
+ "AW0": 58,
+ "DH": 59,
+ "Z": 60,
+ "spn": 61,
+ "AH0": 62,
+ "sp": 63,
+ "AO1": 64,
+ "OW1": 65,
+ "ZH": 66,
+ "B": 67,
+ "AA0": 68,
+ "CH": 69,
+ "HH": 70,
+ }
+ pad_token_id = len(phonemes_lookup_dict.keys())
+
+ assert (
+ "phoneme" in metadata.keys()
+ ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset"
+
+ phonemes = [
+ phonemes_lookup_dict[x]
+ for x in metadata["phoneme"]
+ if (x in phonemes_lookup_dict.keys())
+ ]
+
+ if (len(phonemes) / PAD_LENGTH) > 5:
+ print(
+ "Warning: Phonemes length is too long and is truncated too much! %s"
+ % metadata
+ )
+
+ phonemes = phonemes[:PAD_LENGTH]
+
+ def _pad_phonemes(phonemes_list):
+ return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list))
+
+ return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))}
+
+
+def extract_phoneme_g2p_en_feature(config, dl_output, metadata):
+ PAD_LENGTH = 250
+
+ phonemes_lookup_dict = {
+ " ": 0,
+ "AA": 1,
+ "AE": 2,
+ "AH": 3,
+ "AO": 4,
+ "AW": 5,
+ "AY": 6,
+ "B": 7,
+ "CH": 8,
+ "D": 9,
+ "DH": 10,
+ "EH": 11,
+ "ER": 12,
+ "EY": 13,
+ "F": 14,
+ "G": 15,
+ "HH": 16,
+ "IH": 17,
+ "IY": 18,
+ "JH": 19,
+ "K": 20,
+ "L": 21,
+ "M": 22,
+ "N": 23,
+ "NG": 24,
+ "OW": 25,
+ "OY": 26,
+ "P": 27,
+ "R": 28,
+ "S": 29,
+ "SH": 30,
+ "T": 31,
+ "TH": 32,
+ "UH": 33,
+ "UW": 34,
+ "V": 35,
+ "W": 36,
+ "Y": 37,
+ "Z": 38,
+ "ZH": 39,
+ }
+ pad_token_id = len(phonemes_lookup_dict.keys())
+
+ assert (
+ "phoneme" in metadata.keys()
+ ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset"
+ phonemes = [
+ phonemes_lookup_dict[x]
+ for x in metadata["phoneme"]
+ if (x in phonemes_lookup_dict.keys())
+ ]
+
+ if (len(phonemes) / PAD_LENGTH) > 5:
+ print(
+ "Warning: Phonemes length is too long and is truncated too much! %s"
+ % metadata
+ )
+
+ phonemes = phonemes[:PAD_LENGTH]
+
+ def _pad_phonemes(phonemes_list):
+ return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list))
+
+ return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))}
+
+
+def extract_kaldi_fbank_feature(config, dl_output, metadata):
+ norm_mean = -4.2677393
+ norm_std = 4.5689974
+
+ waveform = dl_output["waveform"] # [1, samples]
+ sampling_rate = dl_output["sampling_rate"]
+ log_mel_spec_hifigan = dl_output["log_mel_spec"]
+
+ if sampling_rate != 16000:
+ waveform_16k = torchaudio.functional.resample(
+ waveform, orig_freq=sampling_rate, new_freq=16000
+ )
+ else:
+ waveform_16k = waveform
+
+ waveform_16k = waveform_16k - waveform_16k.mean()
+ fbank = torchaudio.compliance.kaldi.fbank(
+ waveform_16k,
+ htk_compat=True,
+ sample_frequency=16000,
+ use_energy=False,
+ window_type="hanning",
+ num_mel_bins=128,
+ dither=0.0,
+ frame_shift=10,
+ )
+
+ TARGET_LEN = log_mel_spec_hifigan.size(0)
+
+ # cut and pad
+ n_frames = fbank.shape[0]
+ p = TARGET_LEN - n_frames
+ if p > 0:
+ m = torch.nn.ZeroPad2d((0, 0, 0, p))
+ fbank = m(fbank)
+ elif p < 0:
+ fbank = fbank[:TARGET_LEN, :]
+
+ fbank = (fbank - norm_mean) / (norm_std * 2)
+
+ return {"ta_kaldi_fbank": fbank} # [1024, 128]
+
+
+def extract_kaldi_fbank_feature_32k(config, dl_output, metadata):
+ norm_mean = -4.2677393
+ norm_std = 4.5689974
+
+ waveform = dl_output["waveform"] # [1, samples]
+ sampling_rate = dl_output["sampling_rate"]
+ log_mel_spec_hifigan = dl_output["log_mel_spec"]
+
+ if sampling_rate != 32000:
+ waveform_32k = torchaudio.functional.resample(
+ waveform, orig_freq=sampling_rate, new_freq=32000
+ )
+ else:
+ waveform_32k = waveform
+
+ waveform_32k = waveform_32k - waveform_32k.mean()
+ fbank = torchaudio.compliance.kaldi.fbank(
+ waveform_32k,
+ htk_compat=True,
+ sample_frequency=32000,
+ use_energy=False,
+ window_type="hanning",
+ num_mel_bins=128,
+ dither=0.0,
+ frame_shift=10,
+ )
+
+ TARGET_LEN = log_mel_spec_hifigan.size(0)
+
+ # cut and pad
+ n_frames = fbank.shape[0]
+ p = TARGET_LEN - n_frames
+ if p > 0:
+ m = torch.nn.ZeroPad2d((0, 0, 0, p))
+ fbank = m(fbank)
+ elif p < 0:
+ fbank = fbank[:TARGET_LEN, :]
+
+ fbank = (fbank - norm_mean) / (norm_std * 2)
+
+ return {"ta_kaldi_fbank": fbank} # [1024, 128]
+
+
+# Use the beat and downbeat information as music conditions
+def extract_drum_beat(config, dl_output, metadata):
+ def visualization(conditional_signal, mel_spectrogram, filename):
+ import soundfile as sf
+
+ sf.write(
+ os.path.basename(dl_output["fname"]),
+ np.array(dl_output["waveform"])[0],
+ dl_output["sampling_rate"],
+ )
+ plt.figure(figsize=(10, 10))
+
+ plt.subplot(211)
+ plt.imshow(np.array(conditional_signal).T, aspect="auto")
+ plt.title("Conditional Signal")
+
+ plt.subplot(212)
+ plt.imshow(np.array(mel_spectrogram).T, aspect="auto")
+ plt.title("Mel Spectrogram")
+
+ plt.savefig(filename)
+ plt.close()
+
+ assert "sample_rate" in metadata and "beat" in metadata and "downbeat" in metadata
+
+ sampling_rate = metadata["sample_rate"]
+ duration = dl_output["duration"]
+ # The dataloader segment length before performing torch resampling
+ original_segment_length_before_resample = int(sampling_rate * duration)
+
+ random_start_sample = int(dl_output["random_start_sample_in_original_audio_file"])
+
+ # The sample idx for beat and downbeat, relatively to the segmented audio
+ beat = [
+ x - random_start_sample
+ for x in metadata["beat"]
+ if (
+ x - random_start_sample >= 0
+ and x - random_start_sample <= original_segment_length_before_resample
+ )
+ ]
+ downbeat = [
+ x - random_start_sample
+ for x in metadata["downbeat"]
+ if (
+ x - random_start_sample >= 0
+ and x - random_start_sample <= original_segment_length_before_resample
+ )
+ ]
+
+ latent_shape = (
+ config["model"]["params"]["latent_t_size"],
+ config["model"]["params"]["latent_f_size"],
+ )
+ conditional_signal = torch.zeros(latent_shape)
+
+ # beat: -0.5
+ # downbeat: +1.0
+ # 0: none; -0.5: beat; 1.0: downbeat; 0.5: downbeat+beat
+ for each in beat:
+ beat_index = int(
+ (each / original_segment_length_before_resample) * latent_shape[0]
+ )
+ beat_index = min(beat_index, conditional_signal.size(0) - 1)
+
+ conditional_signal[beat_index, :] -= 0.5
+
+ for each in downbeat:
+ beat_index = int(
+ (each / original_segment_length_before_resample) * latent_shape[0]
+ )
+ beat_index = min(beat_index, conditional_signal.size(0) - 1)
+
+ conditional_signal[beat_index, :] += 1.0
+
+ # visualization(conditional_signal, dl_output["log_mel_spec"], filename = os.path.basename(dl_output["fname"])+".png")
+
+ return {"cond_beat_downbeat": conditional_signal}
diff --git a/audioldm2/utilities/data/dataset.py b/audioldm2/utilities/data/dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..f0bfbb7388ca6473beb4574ac4e29dcf0b7c0571
--- /dev/null
+++ b/audioldm2/utilities/data/dataset.py
@@ -0,0 +1,518 @@
+import os
+import pandas as pd
+
+import audioldm2.utilities.audio as Audio
+from audioldm2.utilities.tools import load_json
+
+import random
+from torch.utils.data import Dataset
+import torch.nn.functional
+import torch
+import numpy as np
+import torchaudio
+
+
+class AudioDataset(Dataset):
+ def __init__(
+ self,
+ config=None,
+ split="train",
+ waveform_only=False,
+ add_ons=[],
+ dataset_json_path=None, #
+ ):
+ """
+ Dataset that manages audio recordings
+ :param audio_conf: Dictionary containing the audio loading and preprocessing settings
+ :param dataset_json_file
+ """
+ self.config = config
+ self.split = split
+ self.pad_wav_start_sample = 0 # If none, random choose
+ self.trim_wav = False
+ self.waveform_only = waveform_only
+ self.add_ons = [eval(x) for x in add_ons]
+ print("Add-ons:", self.add_ons)
+
+ self.build_setting_parameters()
+
+ # For an external dataset
+ if dataset_json_path is not None:
+ assert type(dataset_json_path) == str
+ print("Load metadata from %s" % dataset_json_path)
+ self.data = load_json(dataset_json_path)["data"]
+ self.id2label, self.index_dict, self.num2label = {}, {}, {}
+ else:
+ self.metadata_root = load_json(self.config["metadata_root"])
+ self.dataset_name = self.config["data"][self.split]
+ assert split in self.config["data"].keys(), (
+ "The dataset split %s you specified is not present in the config. You can choose from %s"
+ % (split, self.config["data"].keys())
+ )
+ self.build_dataset()
+ self.build_id_to_label()
+
+ self.build_dsp()
+ self.label_num = len(self.index_dict)
+ print("Dataset initialize finished")
+
+ def __getitem__(self, index):
+ (
+ fname,
+ waveform,
+ stft,
+ log_mel_spec,
+ label_vector, # the one-hot representation of the audio class
+ # the metadata of the sampled audio file and the mixup audio file (if exist)
+ (datum, mix_datum),
+ random_start,
+ ) = self.feature_extraction(index)
+ text = self.get_sample_text_caption(datum, mix_datum, label_vector)
+
+ data = {
+ "text": text, # list
+ "fname": self.text_to_filename(text)
+ if (len(fname) == 0)
+ else fname, # list
+ # tensor, [batchsize, class_num]
+ "label_vector": "" if (label_vector is None) else label_vector.float(),
+ # tensor, [batchsize, 1, samples_num]
+ "waveform": "" if (waveform is None) else waveform.float(),
+ # tensor, [batchsize, t-steps, f-bins]
+ "stft": "" if (stft is None) else stft.float(),
+ # tensor, [batchsize, t-steps, mel-bins]
+ "log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(),
+ "duration": self.duration,
+ "sampling_rate": self.sampling_rate,
+ "random_start_sample_in_original_audio_file": random_start,
+ }
+
+ for add_on in self.add_ons:
+ data.update(add_on(self.config, data, self.data[index]))
+
+ if data["text"] is None:
+ print("Warning: The model return None on key text", fname)
+ data["text"] = ""
+
+ return data
+
+ def text_to_filename(self, text):
+ return text.replace(" ", "_").replace("'", "_").replace('"', "_")
+
+ def get_dataset_root_path(self, dataset):
+ assert dataset in self.metadata_root.keys()
+ return self.metadata_root[dataset]
+
+ def get_dataset_metadata_path(self, dataset, key):
+ # key: train, test, val, class_label_indices
+ try:
+ if dataset in self.metadata_root["metadata"]["path"].keys():
+ return self.metadata_root["metadata"]["path"][dataset][key]
+ except:
+ raise ValueError(
+ 'Dataset %s does not metadata "%s" specified' % (dataset, key)
+ )
+ # return None
+
+ def __len__(self):
+ return len(self.data)
+
+ def feature_extraction(self, index):
+ if index > len(self.data) - 1:
+ print(
+ "The index of the dataloader is out of range: %s/%s"
+ % (index, len(self.data))
+ )
+ index = random.randint(0, len(self.data) - 1)
+
+ # Read wave file and extract feature
+ while True:
+ try:
+ label_indices = np.zeros(self.label_num, dtype=np.float32)
+ datum = self.data[index]
+ (
+ log_mel_spec,
+ stft,
+ mix_lambda,
+ waveform,
+ random_start,
+ ) = self.read_audio_file(datum["wav"])
+ mix_datum = None
+ if self.label_num > 0 and "labels" in datum.keys():
+ for label_str in datum["labels"].split(","):
+ label_indices[int(self.index_dict[label_str])] = 1.0
+
+ # If the key "label" is not in the metadata, return all zero vector
+ label_indices = torch.FloatTensor(label_indices)
+ break
+ except Exception as e:
+ index = (index + 1) % len(self.data)
+ print(
+ "Error encounter during audio feature extraction: ", e, datum["wav"]
+ )
+ continue
+
+ # The filename of the wav file
+ fname = datum["wav"]
+ # t_step = log_mel_spec.size(0)
+ # waveform = torch.FloatTensor(waveform[..., : int(self.hopsize * t_step)])
+ waveform = torch.FloatTensor(waveform)
+
+ return (
+ fname,
+ waveform,
+ stft,
+ log_mel_spec,
+ label_indices,
+ (datum, mix_datum),
+ random_start,
+ )
+
+ # def augmentation(self, log_mel_spec):
+ # assert torch.min(log_mel_spec) < 0
+ # log_mel_spec = log_mel_spec.exp()
+
+ # log_mel_spec = torch.transpose(log_mel_spec, 0, 1)
+ # # this is just to satisfy new torchaudio version.
+ # log_mel_spec = log_mel_spec.unsqueeze(0)
+ # if self.freqm != 0:
+ # log_mel_spec = self.frequency_masking(log_mel_spec, self.freqm)
+ # if self.timem != 0:
+ # log_mel_spec = self.time_masking(
+ # log_mel_spec, self.timem) # self.timem=0
+
+ # log_mel_spec = (log_mel_spec + 1e-7).log()
+ # # squeeze back
+ # log_mel_spec = log_mel_spec.squeeze(0)
+ # log_mel_spec = torch.transpose(log_mel_spec, 0, 1)
+ # return log_mel_spec
+
+ def build_setting_parameters(self):
+ # Read from the json config
+ self.melbins = self.config["preprocessing"]["mel"]["n_mel_channels"]
+ # self.freqm = self.config["preprocessing"]["mel"]["freqm"]
+ # self.timem = self.config["preprocessing"]["mel"]["timem"]
+ self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"]
+ self.hopsize = self.config["preprocessing"]["stft"]["hop_length"]
+ self.duration = self.config["preprocessing"]["audio"]["duration"]
+ self.target_length = int(self.duration * self.sampling_rate / self.hopsize)
+
+ self.mixup = self.config["augmentation"]["mixup"]
+
+ # Calculate parameter derivations
+ # self.waveform_sample_length = int(self.target_length * self.hopsize)
+
+ # if (self.config["balance_sampling_weight"]):
+ # self.samples_weight = np.loadtxt(
+ # self.config["balance_sampling_weight"], delimiter=","
+ # )
+
+ if "train" not in self.split:
+ self.mixup = 0.0
+ # self.freqm = 0
+ # self.timem = 0
+
+ def _relative_path_to_absolute_path(self, metadata, dataset_name):
+ root_path = self.get_dataset_root_path(dataset_name)
+ for i in range(len(metadata["data"])):
+ assert "wav" in metadata["data"][i].keys(), metadata["data"][i]
+ assert metadata["data"][i]["wav"][0] != "/", (
+ "The dataset metadata should only contain relative path to the audio file: "
+ + str(metadata["data"][i]["wav"])
+ )
+ metadata["data"][i]["wav"] = os.path.join(
+ root_path, metadata["data"][i]["wav"]
+ )
+ return metadata
+
+ def build_dataset(self):
+ self.data = []
+ print("Build dataset split %s from %s" % (self.split, self.dataset_name))
+ if type(self.dataset_name) is str:
+ data_json = load_json(
+ self.get_dataset_metadata_path(self.dataset_name, key=self.split)
+ )
+ data_json = self._relative_path_to_absolute_path(
+ data_json, self.dataset_name
+ )
+ self.data = data_json["data"]
+ elif type(self.dataset_name) is list:
+ for dataset_name in self.dataset_name:
+ data_json = load_json(
+ self.get_dataset_metadata_path(dataset_name, key=self.split)
+ )
+ data_json = self._relative_path_to_absolute_path(
+ data_json, dataset_name
+ )
+ self.data += data_json["data"]
+ else:
+ raise Exception("Invalid data format")
+ print("Data size: {}".format(len(self.data)))
+
+ def build_dsp(self):
+ self.STFT = Audio.stft.TacotronSTFT(
+ self.config["preprocessing"]["stft"]["filter_length"],
+ self.config["preprocessing"]["stft"]["hop_length"],
+ self.config["preprocessing"]["stft"]["win_length"],
+ self.config["preprocessing"]["mel"]["n_mel_channels"],
+ self.config["preprocessing"]["audio"]["sampling_rate"],
+ self.config["preprocessing"]["mel"]["mel_fmin"],
+ self.config["preprocessing"]["mel"]["mel_fmax"],
+ )
+ # self.stft_transform = torchaudio.transforms.Spectrogram(
+ # n_fft=1024, hop_length=160
+ # )
+ # self.melscale_transform = torchaudio.transforms.MelScale(
+ # sample_rate=16000, n_stft=1024 // 2 + 1, n_mels=64
+ # )
+
+ def build_id_to_label(self):
+ id2label = {}
+ id2num = {}
+ num2label = {}
+ class_label_indices_path = self.get_dataset_metadata_path(
+ dataset=self.config["data"]["class_label_indices"],
+ key="class_label_indices",
+ )
+ if class_label_indices_path is not None:
+ df = pd.read_csv(class_label_indices_path)
+ for _, row in df.iterrows():
+ index, mid, display_name = row["index"], row["mid"], row["display_name"]
+ id2label[mid] = display_name
+ id2num[mid] = index
+ num2label[index] = display_name
+ self.id2label, self.index_dict, self.num2label = id2label, id2num, num2label
+ else:
+ self.id2label, self.index_dict, self.num2label = {}, {}, {}
+
+ def resample(self, waveform, sr):
+ waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
+ # waveform = librosa.resample(waveform, sr, self.sampling_rate)
+ return waveform
+
+ # if sr == 16000:
+ # return waveform
+ # if sr == 32000 and self.sampling_rate == 16000:
+ # waveform = waveform[::2]
+ # return waveform
+ # if sr == 48000 and self.sampling_rate == 16000:
+ # waveform = waveform[::3]
+ # return waveform
+ # else:
+ # raise ValueError(
+ # "We currently only support 16k audio generation. You need to resample you audio file to 16k, 32k, or 48k: %s, %s"
+ # % (sr, self.sampling_rate)
+ # )
+
+ def normalize_wav(self, waveform):
+ waveform = waveform - np.mean(waveform)
+ waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
+ return waveform * 0.5 # Manually limit the maximum amplitude into 0.5
+
+ def random_segment_wav(self, waveform, target_length):
+ waveform_length = waveform.shape[-1]
+ assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
+
+ # Too short
+ if (waveform_length - target_length) <= 0:
+ return waveform, 0
+
+ random_start = int(self.random_uniform(0, waveform_length - target_length))
+ return waveform[:, random_start : random_start + target_length], random_start
+
+ def pad_wav(self, waveform, target_length):
+ waveform_length = waveform.shape[-1]
+ assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
+
+ if waveform_length == target_length:
+ return waveform
+
+ # Pad
+ temp_wav = np.zeros((1, target_length), dtype=np.float32)
+ if self.pad_wav_start_sample is None:
+ rand_start = int(self.random_uniform(0, target_length - waveform_length))
+ else:
+ rand_start = 0
+
+ temp_wav[:, rand_start : rand_start + waveform_length] = waveform
+ return temp_wav
+
+ def trim_wav(self, waveform):
+ if np.max(np.abs(waveform)) < 0.0001:
+ return waveform
+
+ def detect_leading_silence(waveform, threshold=0.0001):
+ chunk_size = 1000
+ waveform_length = waveform.shape[0]
+ start = 0
+ while start + chunk_size < waveform_length:
+ if np.max(np.abs(waveform[start : start + chunk_size])) < threshold:
+ start += chunk_size
+ else:
+ break
+ return start
+
+ def detect_ending_silence(waveform, threshold=0.0001):
+ chunk_size = 1000
+ waveform_length = waveform.shape[0]
+ start = waveform_length
+ while start - chunk_size > 0:
+ if np.max(np.abs(waveform[start - chunk_size : start])) < threshold:
+ start -= chunk_size
+ else:
+ break
+ if start == waveform_length:
+ return start
+ else:
+ return start + chunk_size
+
+ start = detect_leading_silence(waveform)
+ end = detect_ending_silence(waveform)
+
+ return waveform[start:end]
+
+ def read_wav_file(self, filename):
+ # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
+ waveform, sr = torchaudio.load(filename)
+
+ waveform, random_start = self.random_segment_wav(
+ waveform, target_length=int(sr * self.duration)
+ )
+
+ waveform = self.resample(waveform, sr)
+ # random_start = int(random_start * (self.sampling_rate / sr))
+
+ waveform = waveform.numpy()[0, ...]
+
+ waveform = self.normalize_wav(waveform)
+
+ if self.trim_wav:
+ waveform = self.trim_wav(waveform)
+
+ waveform = waveform[None, ...]
+ waveform = self.pad_wav(
+ waveform, target_length=int(self.sampling_rate * self.duration)
+ )
+ return waveform, random_start
+
+ def mix_two_waveforms(self, waveform1, waveform2):
+ mix_lambda = np.random.beta(5, 5)
+ mix_waveform = mix_lambda * waveform1 + (1 - mix_lambda) * waveform2
+ return self.normalize_wav(mix_waveform), mix_lambda
+
+ def read_audio_file(self, filename, filename2=None):
+ if os.path.exists(filename):
+ waveform, random_start = self.read_wav_file(filename)
+ else:
+ print(
+ 'Warning [dataset.py]: The wav path "',
+ filename,
+ '" is not find in the metadata. Use empty waveform instead.',
+ )
+ target_length = int(self.sampling_rate * self.duration)
+ waveform = torch.zeros((1, target_length))
+ random_start = 0
+
+ mix_lambda = 0.0
+ # log_mel_spec, stft = self.wav_feature_extraction_torchaudio(waveform) # this line is faster, but this implementation is not aligned with HiFi-GAN
+ if not self.waveform_only:
+ log_mel_spec, stft = self.wav_feature_extraction(waveform)
+ else:
+ # Load waveform data only
+ # Use zero array to keep the format unified
+ log_mel_spec, stft = None, None
+
+ return log_mel_spec, stft, mix_lambda, waveform, random_start
+
+ def get_sample_text_caption(self, datum, mix_datum, label_indices):
+ text = self.label_indices_to_text(datum, label_indices)
+ if mix_datum is not None:
+ text += " " + self.label_indices_to_text(mix_datum, label_indices)
+ return text
+
+ # This one is significantly slower than "wav_feature_extraction_torchaudio" if num_worker > 1
+ def wav_feature_extraction(self, waveform):
+ waveform = waveform[0, ...]
+ waveform = torch.FloatTensor(waveform)
+
+ log_mel_spec, stft, energy = Audio.tools.get_mel_from_wav(waveform, self.STFT)
+
+ log_mel_spec = torch.FloatTensor(log_mel_spec.T)
+ stft = torch.FloatTensor(stft.T)
+
+ log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
+ return log_mel_spec, stft
+
+ # @profile
+ # def wav_feature_extraction_torchaudio(self, waveform):
+ # waveform = waveform[0, ...]
+ # waveform = torch.FloatTensor(waveform)
+
+ # stft = self.stft_transform(waveform)
+ # mel_spec = self.melscale_transform(stft)
+ # log_mel_spec = torch.log(mel_spec + 1e-7)
+
+ # log_mel_spec = torch.FloatTensor(log_mel_spec.T)
+ # stft = torch.FloatTensor(stft.T)
+
+ # log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
+ # return log_mel_spec, stft
+
+ def pad_spec(self, log_mel_spec):
+ n_frames = log_mel_spec.shape[0]
+ p = self.target_length - n_frames
+ # cut and pad
+ if p > 0:
+ m = torch.nn.ZeroPad2d((0, 0, 0, p))
+ log_mel_spec = m(log_mel_spec)
+ elif p < 0:
+ log_mel_spec = log_mel_spec[0 : self.target_length, :]
+
+ if log_mel_spec.size(-1) % 2 != 0:
+ log_mel_spec = log_mel_spec[..., :-1]
+
+ return log_mel_spec
+
+ def _read_datum_caption(self, datum):
+ caption_keys = [x for x in datum.keys() if ("caption" in x)]
+ random_index = torch.randint(0, len(caption_keys), (1,))[0].item()
+ return datum[caption_keys[random_index]]
+
+ def _is_contain_caption(self, datum):
+ caption_keys = [x for x in datum.keys() if ("caption" in x)]
+ return len(caption_keys) > 0
+
+ def label_indices_to_text(self, datum, label_indices):
+ if self._is_contain_caption(datum):
+ return self._read_datum_caption(datum)
+ elif "label" in datum.keys():
+ name_indices = torch.where(label_indices > 0.1)[0]
+ # description_header = "This audio contains the sound of "
+ description_header = ""
+ labels = ""
+ for id, each in enumerate(name_indices):
+ if id == len(name_indices) - 1:
+ labels += "%s." % self.num2label[int(each)]
+ else:
+ labels += "%s, " % self.num2label[int(each)]
+ return description_header + labels
+ else:
+ return "" # TODO, if both label and caption are not provided, return empty string
+
+ def random_uniform(self, start, end):
+ val = torch.rand(1).item()
+ return start + (end - start) * val
+
+ def frequency_masking(self, log_mel_spec, freqm):
+ bs, freq, tsteps = log_mel_spec.size()
+ mask_len = int(self.random_uniform(freqm // 8, freqm))
+ mask_start = int(self.random_uniform(start=0, end=freq - mask_len))
+ log_mel_spec[:, mask_start : mask_start + mask_len, :] *= 0.0
+ return log_mel_spec
+
+ def time_masking(self, log_mel_spec, timem):
+ bs, freq, tsteps = log_mel_spec.size()
+ mask_len = int(self.random_uniform(timem // 8, timem))
+ mask_start = int(self.random_uniform(start=0, end=tsteps - mask_len))
+ log_mel_spec[:, :, mask_start : mask_start + mask_len] *= 0.0
+ return log_mel_spec
diff --git a/audioldm2/utilities/model.py b/audioldm2/utilities/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..ffefac1212b85bfb8c4992371dbdf6d500a969e3
--- /dev/null
+++ b/audioldm2/utilities/model.py
@@ -0,0 +1,121 @@
+import torch
+
+import audioldm2.hifigan as hifigan
+
+
+def get_vocoder_config():
+ return {
+ "resblock": "1",
+ "num_gpus": 6,
+ "batch_size": 16,
+ "learning_rate": 0.0002,
+ "adam_b1": 0.8,
+ "adam_b2": 0.99,
+ "lr_decay": 0.999,
+ "seed": 1234,
+ "upsample_rates": [5, 4, 2, 2, 2],
+ "upsample_kernel_sizes": [16, 16, 8, 4, 4],
+ "upsample_initial_channel": 1024,
+ "resblock_kernel_sizes": [3, 7, 11],
+ "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ "segment_size": 8192,
+ "num_mels": 64,
+ "num_freq": 1025,
+ "n_fft": 1024,
+ "hop_size": 160,
+ "win_size": 1024,
+ "sampling_rate": 16000,
+ "fmin": 0,
+ "fmax": 8000,
+ "fmax_for_loss": None,
+ "num_workers": 4,
+ "dist_config": {
+ "dist_backend": "nccl",
+ "dist_url": "tcp://localhost:54321",
+ "world_size": 1,
+ },
+ }
+
+
+def get_available_checkpoint_keys(model, ckpt):
+ state_dict = torch.load(ckpt)["state_dict"]
+ current_state_dict = model.state_dict()
+ new_state_dict = {}
+ for k in state_dict.keys():
+ if (
+ k in current_state_dict.keys()
+ and current_state_dict[k].size() == state_dict[k].size()
+ ):
+ new_state_dict[k] = state_dict[k]
+ else:
+ print("==> WARNING: Skipping %s" % k)
+ print(
+ "%s out of %s keys are matched"
+ % (len(new_state_dict.keys()), len(state_dict.keys()))
+ )
+ return new_state_dict
+
+
+def get_param_num(model):
+ num_param = sum(param.numel() for param in model.parameters())
+ return num_param
+
+
+def torch_version_orig_mod_remove(state_dict):
+ new_state_dict = {}
+ new_state_dict["generator"] = {}
+ for key in state_dict["generator"].keys():
+ if "_orig_mod." in key:
+ new_state_dict["generator"][key.replace("_orig_mod.", "")] = state_dict[
+ "generator"
+ ][key]
+ else:
+ new_state_dict["generator"][key] = state_dict["generator"][key]
+ return new_state_dict
+
+
+def get_vocoder(config, device, mel_bins):
+ name = "HiFi-GAN"
+ speaker = ""
+ if name == "MelGAN":
+ if speaker == "LJSpeech":
+ vocoder = torch.hub.load(
+ "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
+ )
+ elif speaker == "universal":
+ vocoder = torch.hub.load(
+ "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
+ )
+ vocoder.mel2wav.eval()
+ vocoder.mel2wav.to(device)
+ elif name == "HiFi-GAN":
+ config = get_vocoder_config()
+ config = hifigan.AttrDict(config)
+ vocoder = hifigan.Generator_old(config)
+ # print("Load hifigan/g_01080000")
+ # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
+ # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
+ # ckpt = torch_version_orig_mod_remove(ckpt)
+ # vocoder.load_state_dict(ckpt["generator"])
+ vocoder.eval()
+ vocoder.remove_weight_norm()
+ vocoder.to(device)
+ return vocoder
+
+
+def vocoder_infer(mels, vocoder, lengths=None):
+ with torch.no_grad():
+ wavs = vocoder(mels).squeeze(1)
+
+ wavs = (wavs.cpu().numpy() * 32768).astype("int16")
+
+ if lengths is not None:
+ wavs = wavs[:, :lengths]
+
+ # wavs = [wav for wav in wavs]
+
+ # for i in range(len(mels)):
+ # if lengths is not None:
+ # wavs[i] = wavs[i][: lengths[i]]
+
+ return wavs
diff --git a/audioldm2/utilities/sampler.py b/audioldm2/utilities/sampler.py
new file mode 100755
index 0000000000000000000000000000000000000000..cdaf4882715f53f39ead8bf71fb3dccc29cd8b94
--- /dev/null
+++ b/audioldm2/utilities/sampler.py
@@ -0,0 +1,588 @@
+from typing import Iterator, List, Optional, Union
+from collections import Counter
+import logging
+from operator import itemgetter
+import random
+
+import numpy as np
+
+from torch.utils.data import DistributedSampler
+from torch.utils.data.sampler import Sampler
+
+LOGGER = logging.getLogger(__name__)
+
+from torch.utils.data import Dataset, Sampler
+
+
+class DatasetFromSampler(Dataset):
+ """Dataset to create indexes from `Sampler`.
+ Args:
+ sampler: PyTorch sampler
+ """
+
+ def __init__(self, sampler: Sampler):
+ """Initialisation for DatasetFromSampler."""
+ self.sampler = sampler
+ self.sampler_list = None
+
+ def __getitem__(self, index: int):
+ """Gets element of the dataset.
+ Args:
+ index: index of the element in the dataset
+ Returns:
+ Single element by index
+ """
+ if self.sampler_list is None:
+ self.sampler_list = list(self.sampler)
+ return self.sampler_list[index]
+
+ def __len__(self) -> int:
+ """
+ Returns:
+ int: length of the dataset
+ """
+ return len(self.sampler)
+
+
+class BalanceClassSampler(Sampler):
+ """Allows you to create stratified sample on unbalanced classes.
+
+ Args:
+ labels: list of class label for each elem in the dataset
+ mode: Strategy to balance classes.
+ Must be one of [downsampling, upsampling]
+
+ Python API examples:
+
+ .. code-block:: python
+
+ import os
+ from torch import nn, optim
+ from torch.utils.data import DataLoader
+ from catalyst import dl
+ from catalyst.data import ToTensor, BalanceClassSampler
+ from catalyst.contrib.datasets import MNIST
+
+ train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
+ train_labels = train_data.targets.cpu().numpy().tolist()
+ train_sampler = BalanceClassSampler(train_labels, mode=5000)
+ valid_data = MNIST(os.getcwd(), train=False)
+
+ loaders = {
+ "train": DataLoader(train_data, sampler=train_sampler, batch_size=32),
+ "valid": DataLoader(valid_data, batch_size=32),
+ }
+
+ model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.Adam(model.parameters(), lr=0.02)
+
+ runner = dl.SupervisedRunner()
+ # model training
+ runner.train(
+ model=model,
+ criterion=criterion,
+ optimizer=optimizer,
+ loaders=loaders,
+ num_epochs=1,
+ logdir="./logs",
+ valid_loader="valid",
+ valid_metric="loss",
+ minimize_valid_metric=True,
+ verbose=True,
+ )
+ """
+
+ def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"):
+ """Sampler initialisation."""
+ super().__init__(labels)
+
+ labels = np.array(labels)
+ samples_per_class = {label: (labels == label).sum() for label in set(labels)}
+
+ self.lbl2idx = {
+ label: np.arange(len(labels))[labels == label].tolist()
+ for label in set(labels)
+ }
+
+ if isinstance(mode, str):
+ assert mode in ["downsampling", "upsampling"]
+
+ if isinstance(mode, int) or mode == "upsampling":
+ samples_per_class = (
+ mode if isinstance(mode, int) else max(samples_per_class.values())
+ )
+ else:
+ samples_per_class = min(samples_per_class.values())
+
+ self.labels = labels
+ self.samples_per_class = samples_per_class
+ self.length = self.samples_per_class * len(set(labels))
+
+ def __iter__(self) -> Iterator[int]:
+ """
+ Returns:
+ iterator of indices of stratified sample
+ """
+ indices = []
+ for key in sorted(self.lbl2idx):
+ replace_flag = self.samples_per_class > len(self.lbl2idx[key])
+ indices += np.random.choice(
+ self.lbl2idx[key], self.samples_per_class, replace=replace_flag
+ ).tolist()
+ assert len(indices) == self.length
+ np.random.shuffle(indices)
+
+ return iter(indices)
+
+ def __len__(self) -> int:
+ """
+ Returns:
+ length of result sample
+ """
+ return self.length
+
+
+class BatchBalanceClassSampler(Sampler):
+ """
+ This kind of sampler can be used for both metric learning and classification task.
+
+ BatchSampler with the given strategy for the C unique classes dataset:
+ - Selection `num_classes` of C classes for each batch
+ - Selection `num_samples` instances for each class in the batch
+ The epoch ends after `num_batches`.
+ So, the batch sise is `num_classes` * `num_samples`.
+
+ One of the purposes of this sampler is to be used for
+ forming triplets and pos/neg pairs inside the batch.
+ To guarante existance of these pairs in the batch,
+ `num_classes` and `num_samples` should be > 1. (1)
+
+ This type of sampling can be found in the classical paper of Person Re-Id,
+ where P (`num_classes`) equals 32 and K (`num_samples`) equals 4:
+ `In Defense of the Triplet Loss for Person Re-Identification`_.
+
+ Args:
+ labels: list of classes labeles for each elem in the dataset
+ num_classes: number of classes in a batch, should be > 1
+ num_samples: number of instances of each class in a batch, should be > 1
+ num_batches: number of batches in epoch
+ (default = len(labels) // (num_classes * num_samples))
+
+ .. _In Defense of the Triplet Loss for Person Re-Identification:
+ https://arxiv.org/abs/1703.07737
+
+ Python API examples:
+
+ .. code-block:: python
+
+ import os
+ from torch import nn, optim
+ from torch.utils.data import DataLoader
+ from catalyst import dl
+ from catalyst.data import ToTensor, BatchBalanceClassSampler
+ from catalyst.contrib.datasets import MNIST
+
+ train_data = MNIST(os.getcwd(), train=True, download=True)
+ train_labels = train_data.targets.cpu().numpy().tolist()
+ train_sampler = BatchBalanceClassSampler(
+ train_labels, num_classes=10, num_samples=4)
+ valid_data = MNIST(os.getcwd(), train=False)
+
+ loaders = {
+ "train": DataLoader(train_data, batch_sampler=train_sampler),
+ "valid": DataLoader(valid_data, batch_size=32),
+ }
+
+ model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.Adam(model.parameters(), lr=0.02)
+
+ runner = dl.SupervisedRunner()
+ # model training
+ runner.train(
+ model=model,
+ criterion=criterion,
+ optimizer=optimizer,
+ loaders=loaders,
+ num_epochs=1,
+ logdir="./logs",
+ valid_loader="valid",
+ valid_metric="loss",
+ minimize_valid_metric=True,
+ verbose=True,
+ )
+ """
+
+ def __init__(
+ self,
+ labels: Union[List[int], np.ndarray],
+ num_classes: int,
+ num_samples: int,
+ num_batches: int = None,
+ ):
+ """Sampler initialisation."""
+ super().__init__(labels)
+ classes = set(labels)
+
+ assert isinstance(num_classes, int) and isinstance(num_samples, int)
+ assert (1 < num_classes <= len(classes)) and (1 < num_samples)
+ assert all(
+ n > 1 for n in Counter(labels).values()
+ ), "Each class shoud contain at least 2 instances to fit (1)"
+
+ labels = np.array(labels)
+ self._labels = list(set(labels.tolist()))
+ self._num_classes = num_classes
+ self._num_samples = num_samples
+ self._batch_size = self._num_classes * self._num_samples
+ self._num_batches = num_batches or len(labels) // self._batch_size
+ self.lbl2idx = {
+ label: np.arange(len(labels))[labels == label].tolist()
+ for label in set(labels)
+ }
+
+ @property
+ def batch_size(self) -> int:
+ """
+ Returns:
+ this value should be used in DataLoader as batch size
+ """
+ return self._batch_size
+
+ @property
+ def batches_in_epoch(self) -> int:
+ """
+ Returns:
+ number of batches in an epoch
+ """
+ return self._num_batches
+
+ def __len__(self) -> int:
+ """
+ Returns:
+ number of samples in an epoch
+ """
+ return self._num_batches # * self._batch_size
+
+ def __iter__(self) -> Iterator[int]:
+ """
+ Returns:
+ indeces for sampling dataset elems during an epoch
+ """
+ indices = []
+ for _ in range(self._num_batches):
+ batch_indices = []
+ classes_for_batch = random.sample(self._labels, self._num_classes)
+ while self._num_classes != len(set(classes_for_batch)):
+ classes_for_batch = random.sample(self._labels, self._num_classes)
+ for cls_id in classes_for_batch:
+ replace_flag = self._num_samples > len(self.lbl2idx[cls_id])
+ batch_indices += np.random.choice(
+ self.lbl2idx[cls_id], self._num_samples, replace=replace_flag
+ ).tolist()
+ indices.append(batch_indices)
+ return iter(indices)
+
+
+class DynamicBalanceClassSampler(Sampler):
+ """
+ This kind of sampler can be used for classification tasks with significant
+ class imbalance.
+
+ The idea of this sampler that we start with the original class distribution
+ and gradually move to uniform class distribution like with downsampling.
+
+ Let's define :math: D_i = #C_i/ #C_min where :math: #C_i is a size of class
+ i and :math: #C_min is a size of the rarest class, so :math: D_i define
+ class distribution. Also define :math: g(n_epoch) is a exponential
+ scheduler. On each epoch current :math: D_i calculated as
+ :math: current D_i = D_i ^ g(n_epoch),
+ after this data samples according this distribution.
+
+ Notes:
+ In the end of the training, epochs will contain only
+ min_size_class * n_classes examples. So, possible it will not
+ necessary to do validation on each epoch. For this reason use
+ ControlFlowCallback.
+
+ Examples:
+
+ >>> import torch
+ >>> import numpy as np
+
+ >>> from catalyst.data import DynamicBalanceClassSampler
+ >>> from torch.utils import data
+
+ >>> features = torch.Tensor(np.random.random((200, 100)))
+ >>> labels = np.random.randint(0, 4, size=(200,))
+ >>> sampler = DynamicBalanceClassSampler(labels)
+ >>> labels = torch.LongTensor(labels)
+ >>> dataset = data.TensorDataset(features, labels)
+ >>> loader = data.dataloader.DataLoader(dataset, batch_size=8)
+
+ >>> for batch in loader:
+ >>> b_features, b_labels = batch
+
+ Sampler was inspired by https://arxiv.org/abs/1901.06783
+ """
+
+ def __init__(
+ self,
+ labels: List[Union[int, str]],
+ exp_lambda: float = 0.9,
+ start_epoch: int = 0,
+ max_d: Optional[int] = None,
+ mode: Union[str, int] = "downsampling",
+ ignore_warning: bool = False,
+ ):
+ """
+ Args:
+ labels: list of labels for each elem in the dataset
+ exp_lambda: exponent figure for schedule
+ start_epoch: start epoch number, can be useful for multi-stage
+ experiments
+ max_d: if not None, limit on the difference between the most
+ frequent and the rarest classes, heuristic
+ mode: number of samples per class in the end of training. Must be
+ "downsampling" or number. Before change it, make sure that you
+ understand how does it work
+ ignore_warning: ignore warning about min class size
+ """
+ assert isinstance(start_epoch, int)
+ assert 0 < exp_lambda < 1, "exp_lambda must be in (0, 1)"
+ super().__init__(labels)
+ self.exp_lambda = exp_lambda
+ if max_d is None:
+ max_d = np.inf
+ self.max_d = max_d
+ self.epoch = start_epoch
+ labels = np.array(labels)
+ samples_per_class = Counter(labels)
+ self.min_class_size = min(samples_per_class.values())
+
+ if self.min_class_size < 100 and not ignore_warning:
+ LOGGER.warning(
+ f"the smallest class contains only"
+ f" {self.min_class_size} examples. At the end of"
+ f" training, epochs will contain only"
+ f" {self.min_class_size * len(samples_per_class)}"
+ f" examples"
+ )
+
+ self.original_d = {
+ key: value / self.min_class_size for key, value in samples_per_class.items()
+ }
+ self.label2idxes = {
+ label: np.arange(len(labels))[labels == label].tolist()
+ for label in set(labels)
+ }
+
+ if isinstance(mode, int):
+ self.min_class_size = mode
+ else:
+ assert mode == "downsampling"
+
+ self.labels = labels
+ self._update()
+
+ def _update(self) -> None:
+ """Update d coefficients."""
+ current_d = {
+ key: min(value ** self._exp_scheduler(), self.max_d)
+ for key, value in self.original_d.items()
+ }
+ samples_per_classes = {
+ key: int(value * self.min_class_size) for key, value in current_d.items()
+ }
+ self.samples_per_classes = samples_per_classes
+ self.length = np.sum(list(samples_per_classes.values()))
+ self.epoch += 1
+
+ def _exp_scheduler(self) -> float:
+ return self.exp_lambda**self.epoch
+
+ def __iter__(self) -> Iterator[int]:
+ """
+ Returns:
+ iterator of indices of stratified sample
+ """
+ indices = []
+ for key in sorted(self.label2idxes):
+ samples_per_class = self.samples_per_classes[key]
+ replace_flag = samples_per_class > len(self.label2idxes[key])
+ indices += np.random.choice(
+ self.label2idxes[key], samples_per_class, replace=replace_flag
+ ).tolist()
+ assert len(indices) == self.length
+ np.random.shuffle(indices)
+ self._update()
+ return iter(indices)
+
+ def __len__(self) -> int:
+ """
+ Returns:
+ length of result sample
+ """
+ return self.length
+
+
+class MiniEpochSampler(Sampler):
+ """
+ Sampler iterates mini epochs from the dataset used by ``mini_epoch_len``.
+
+ Args:
+ data_len: Size of the dataset
+ mini_epoch_len: Num samples from the dataset used in one
+ mini epoch.
+ drop_last: If ``True``, sampler will drop the last batches
+ if its size would be less than ``batches_per_epoch``
+ shuffle: one of ``"always"``, ``"real_epoch"``, or `None``.
+ The sampler will shuffle indices
+ > "per_mini_epoch" - every mini epoch (every ``__iter__`` call)
+ > "per_epoch" -- every real epoch
+ > None -- don't shuffle
+
+ Example:
+ >>> MiniEpochSampler(len(dataset), mini_epoch_len=100)
+ >>> MiniEpochSampler(len(dataset), mini_epoch_len=100, drop_last=True)
+ >>> MiniEpochSampler(len(dataset), mini_epoch_len=100,
+ >>> shuffle="per_epoch")
+ """
+
+ def __init__(
+ self,
+ data_len: int,
+ mini_epoch_len: int,
+ drop_last: bool = False,
+ shuffle: str = None,
+ ):
+ """Sampler initialisation."""
+ super().__init__(None)
+
+ self.data_len = int(data_len)
+ self.mini_epoch_len = int(mini_epoch_len)
+
+ self.steps = int(data_len / self.mini_epoch_len)
+ self.state_i = 0
+
+ has_reminder = data_len - self.steps * mini_epoch_len > 0
+ if self.steps == 0:
+ self.divider = 1
+ elif has_reminder and not drop_last:
+ self.divider = self.steps + 1
+ else:
+ self.divider = self.steps
+
+ self._indices = np.arange(self.data_len)
+ self.indices = self._indices
+ self.end_pointer = max(self.data_len, self.mini_epoch_len)
+
+ if not (shuffle is None or shuffle in ["per_mini_epoch", "per_epoch"]):
+ raise ValueError(
+ "Shuffle must be one of ['per_mini_epoch', 'per_epoch']. "
+ + f"Got {shuffle}"
+ )
+ self.shuffle_type = shuffle
+
+ def shuffle(self) -> None:
+ """Shuffle sampler indices."""
+ if self.shuffle_type == "per_mini_epoch" or (
+ self.shuffle_type == "per_epoch" and self.state_i == 0
+ ):
+ if self.data_len >= self.mini_epoch_len:
+ self.indices = self._indices
+ np.random.shuffle(self.indices)
+ else:
+ self.indices = np.random.choice(
+ self._indices, self.mini_epoch_len, replace=True
+ )
+
+ def __iter__(self) -> Iterator[int]:
+ """Iterate over sampler.
+
+ Returns:
+ python iterator
+ """
+ self.state_i = self.state_i % self.divider
+ self.shuffle()
+
+ start = self.state_i * self.mini_epoch_len
+ stop = (
+ self.end_pointer
+ if (self.state_i == self.steps)
+ else (self.state_i + 1) * self.mini_epoch_len
+ )
+ indices = self.indices[start:stop].tolist()
+
+ self.state_i += 1
+ return iter(indices)
+
+ def __len__(self) -> int:
+ """
+ Returns:
+ int: length of the mini-epoch
+ """
+ return self.mini_epoch_len
+
+
+class DistributedSamplerWrapper(DistributedSampler):
+ """
+ Wrapper over `Sampler` for distributed training.
+ Allows you to use any sampler in distributed mode.
+
+ It is especially useful in conjunction with
+ `torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSamplerWrapper instance as a DataLoader
+ sampler, and load a subset of subsampled data of the original dataset
+ that is exclusive to it.
+
+ .. note::
+ Sampler is assumed to be of constant size.
+ """
+
+ def __init__(
+ self,
+ sampler,
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = True,
+ ):
+ """
+
+ Args:
+ sampler: Sampler used for subsampling
+ num_replicas (int, optional): Number of processes participating in
+ distributed training
+ rank (int, optional): Rank of the current process
+ within ``num_replicas``
+ shuffle (bool, optional): If true (default),
+ sampler will shuffle the indices
+ """
+ super(DistributedSamplerWrapper, self).__init__(
+ DatasetFromSampler(sampler),
+ num_replicas=num_replicas,
+ rank=rank,
+ shuffle=shuffle,
+ )
+ self.sampler = sampler
+
+ def __iter__(self) -> Iterator[int]:
+ """Iterate over sampler.
+
+ Returns:
+ python iterator
+ """
+ self.dataset = DatasetFromSampler(self.sampler)
+ indexes_of_indexes = super().__iter__()
+ subsampler_indexes = self.dataset
+ return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
+
+
+__all__ = [
+ "BalanceClassSampler",
+ "BatchBalanceClassSampler",
+ "DistributedSamplerWrapper",
+ "DynamicBalanceClassSampler",
+ "MiniEpochSampler",
+]
diff --git a/audioldm2/utilities/tools.py b/audioldm2/utilities/tools.py
new file mode 100755
index 0000000000000000000000000000000000000000..a647a272cdf076b2ae9785bc83724ebd7a897642
--- /dev/null
+++ b/audioldm2/utilities/tools.py
@@ -0,0 +1,541 @@
+# Author: Haohe Liu
+# Email: haoheliu@gmail.com
+# Date: 11 Feb 2023
+
+import os
+import json
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import matplotlib
+from scipy.io import wavfile
+from matplotlib import pyplot as plt
+
+
+matplotlib.use("Agg")
+
+import hashlib
+import os
+
+import requests
+from tqdm import tqdm
+
+URL_MAP = {
+ "vggishish_lpaps": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt",
+ "vggishish_mean_std_melspec_10s_22050hz": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt",
+ "melception": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt",
+}
+
+CKPT_MAP = {
+ "vggishish_lpaps": "vggishish16.pt",
+ "vggishish_mean_std_melspec_10s_22050hz": "train_means_stds_melspec_10s_22050hz.txt",
+ "melception": "melception-21-05-10T09-28-40.pt",
+}
+
+MD5_MAP = {
+ "vggishish_lpaps": "197040c524a07ccacf7715d7080a80bd",
+ "vggishish_mean_std_melspec_10s_22050hz": "f449c6fd0e248936c16f6d22492bb625",
+ "melception": "a71a41041e945b457c7d3d814bbcf72d",
+}
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def load_json(fname):
+ with open(fname, "r") as f:
+ data = json.load(f)
+ return data
+
+
+def read_json(dataset_json_file):
+ with open(dataset_json_file, "r") as fp:
+ data_json = json.load(fp)
+ return data_json["data"]
+
+
+def copy_test_subset_data(metadata, testset_copy_target_path):
+ # metadata = read_json(testset_metadata)
+ os.makedirs(testset_copy_target_path, exist_ok=True)
+ if len(os.listdir(testset_copy_target_path)) == len(metadata):
+ return
+ else:
+ # delete files in folder testset_copy_target_path
+ for file in os.listdir(testset_copy_target_path):
+ try:
+ os.remove(os.path.join(testset_copy_target_path, file))
+ except Exception as e:
+ print(e)
+
+ print("Copying test subset data to {}".format(testset_copy_target_path))
+ for each in tqdm(metadata):
+ cmd = "cp {} {}".format(each["wav"], os.path.join(testset_copy_target_path))
+ os.system(cmd)
+
+
+def listdir_nohidden(path):
+ for f in os.listdir(path):
+ if not f.startswith("."):
+ yield f
+
+
+def get_restore_step(path):
+ checkpoints = os.listdir(path)
+ if os.path.exists(os.path.join(path, "final.ckpt")):
+ return "final.ckpt", 0
+ elif not os.path.exists(os.path.join(path, "last.ckpt")):
+ steps = [int(x.split(".ckpt")[0].split("step=")[1]) for x in checkpoints]
+ return checkpoints[np.argmax(steps)], np.max(steps)
+ else:
+ steps = []
+ for x in checkpoints:
+ if "last" in x:
+ if "-v" not in x:
+ fname = "last.ckpt"
+ else:
+ this_version = int(x.split(".ckpt")[0].split("-v")[1])
+ steps.append(this_version)
+ if len(steps) == 0 or this_version > np.max(steps):
+ fname = "last-v%s.ckpt" % this_version
+ return fname, 0
+
+
+def download(url, local_path, chunk_size=1024):
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+ with requests.get(url, stream=True) as r:
+ total_size = int(r.headers.get("content-length", 0))
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+ with open(local_path, "wb") as f:
+ for data in r.iter_content(chunk_size=chunk_size):
+ if data:
+ f.write(data)
+ pbar.update(chunk_size)
+
+
+def md5_hash(path):
+ with open(path, "rb") as f:
+ content = f.read()
+ return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+ assert name in URL_MAP
+ path = os.path.join(root, CKPT_MAP[name])
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+ download(URL_MAP[name], path)
+ md5 = md5_hash(path)
+ assert md5 == MD5_MAP[name], md5
+ return path
+
+
+class KeyNotFoundError(Exception):
+ def __init__(self, cause, keys=None, visited=None):
+ self.cause = cause
+ self.keys = keys
+ self.visited = visited
+ messages = list()
+ if keys is not None:
+ messages.append("Key not found: {}".format(keys))
+ if visited is not None:
+ messages.append("Visited: {}".format(visited))
+ messages.append("Cause:\n{}".format(cause))
+ message = "\n".join(messages)
+ super().__init__(message)
+
+
+def retrieve(
+ list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
+):
+ """Given a nested list or dict return the desired value at key expanding
+ callable nodes if necessary and :attr:`expand` is ``True``. The expansion
+ is done in-place.
+
+ Parameters
+ ----------
+ list_or_dict : list or dict
+ Possibly nested list or dictionary.
+ key : str
+ key/to/value, path like string describing all keys necessary to
+ consider to get to the desired value. List indices can also be
+ passed here.
+ splitval : str
+ String that defines the delimiter between keys of the
+ different depth levels in `key`.
+ default : obj
+ Value returned if :attr:`key` is not found.
+ expand : bool
+ Whether to expand callable nodes on the path or not.
+
+ Returns
+ -------
+ The desired value or if :attr:`default` is not ``None`` and the
+ :attr:`key` is not found returns ``default``.
+
+ Raises
+ ------
+ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
+ ``None``.
+ """
+
+ keys = key.split(splitval)
+
+ success = True
+ try:
+ visited = []
+ parent = None
+ last_key = None
+ for key in keys:
+ if callable(list_or_dict):
+ if not expand:
+ raise KeyNotFoundError(
+ ValueError(
+ "Trying to get past callable node with expand=False."
+ ),
+ keys=keys,
+ visited=visited,
+ )
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+
+ last_key = key
+ parent = list_or_dict
+
+ try:
+ if isinstance(list_or_dict, dict):
+ list_or_dict = list_or_dict[key]
+ else:
+ list_or_dict = list_or_dict[int(key)]
+ except (KeyError, IndexError, ValueError) as e:
+ raise KeyNotFoundError(e, keys=keys, visited=visited)
+
+ visited += [key]
+ # final expansion of retrieved value
+ if expand and callable(list_or_dict):
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+ except KeyNotFoundError as e:
+ if default is None:
+ raise e
+ else:
+ list_or_dict = default
+ success = False
+
+ if not pass_success:
+ return list_or_dict
+ else:
+ return list_or_dict, success
+
+
+def to_device(data, device):
+ if len(data) == 12:
+ (
+ ids,
+ raw_texts,
+ speakers,
+ texts,
+ src_lens,
+ max_src_len,
+ mels,
+ mel_lens,
+ max_mel_len,
+ pitches,
+ energies,
+ durations,
+ ) = data
+
+ speakers = torch.from_numpy(speakers).long().to(device)
+ texts = torch.from_numpy(texts).long().to(device)
+ src_lens = torch.from_numpy(src_lens).to(device)
+ mels = torch.from_numpy(mels).float().to(device)
+ mel_lens = torch.from_numpy(mel_lens).to(device)
+ pitches = torch.from_numpy(pitches).float().to(device)
+ energies = torch.from_numpy(energies).to(device)
+ durations = torch.from_numpy(durations).long().to(device)
+
+ return (
+ ids,
+ raw_texts,
+ speakers,
+ texts,
+ src_lens,
+ max_src_len,
+ mels,
+ mel_lens,
+ max_mel_len,
+ pitches,
+ energies,
+ durations,
+ )
+
+ if len(data) == 6:
+ (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data
+
+ speakers = torch.from_numpy(speakers).long().to(device)
+ texts = torch.from_numpy(texts).long().to(device)
+ src_lens = torch.from_numpy(src_lens).to(device)
+
+ return (ids, raw_texts, speakers, texts, src_lens, max_src_len)
+
+
+def log(logger, step=None, fig=None, audio=None, sampling_rate=22050, tag=""):
+ # if losses is not None:
+ # logger.add_scalar("Loss/total_loss", losses[0], step)
+ # logger.add_scalar("Loss/mel_loss", losses[1], step)
+ # logger.add_scalar("Loss/mel_postnet_loss", losses[2], step)
+ # logger.add_scalar("Loss/pitch_loss", losses[3], step)
+ # logger.add_scalar("Loss/energy_loss", losses[4], step)
+ # logger.add_scalar("Loss/duration_loss", losses[5], step)
+ # if(len(losses) > 6):
+ # logger.add_scalar("Loss/disc_loss", losses[6], step)
+ # logger.add_scalar("Loss/fmap_loss", losses[7], step)
+ # logger.add_scalar("Loss/r_loss", losses[8], step)
+ # logger.add_scalar("Loss/g_loss", losses[9], step)
+ # logger.add_scalar("Loss/gen_loss", losses[10], step)
+ # logger.add_scalar("Loss/diff_loss", losses[11], step)
+
+ if fig is not None:
+ logger.add_figure(tag, fig)
+
+ if audio is not None:
+ audio = audio / (max(abs(audio)) * 1.1)
+ logger.add_audio(
+ tag,
+ audio,
+ sample_rate=sampling_rate,
+ )
+
+
+def get_mask_from_lengths(lengths, max_len=None):
+ batch_size = lengths.shape[0]
+ if max_len is None:
+ max_len = torch.max(lengths).item()
+
+ ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
+ mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
+
+ return mask
+
+
+def expand(values, durations):
+ out = list()
+ for value, d in zip(values, durations):
+ out += [value] * max(0, int(d))
+ return np.array(out)
+
+
+def synth_one_sample_val(
+ targets, predictions, vocoder, model_config, preprocess_config
+):
+ index = np.random.choice(list(np.arange(targets[6].size(0))))
+
+ basename = targets[0][index]
+ src_len = predictions[8][index].item()
+ mel_len = predictions[9][index].item()
+ mel_target = targets[6][index, :mel_len].detach().transpose(0, 1)
+
+ mel_prediction = predictions[0][index, :mel_len].detach().transpose(0, 1)
+ postnet_mel_prediction = predictions[1][index, :mel_len].detach().transpose(0, 1)
+ duration = targets[11][index, :src_len].detach().cpu().numpy()
+
+ if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
+ pitch = predictions[2][index, :src_len].detach().cpu().numpy()
+ pitch = expand(pitch, duration)
+ else:
+ pitch = predictions[2][index, :mel_len].detach().cpu().numpy()
+
+ if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
+ energy = predictions[3][index, :src_len].detach().cpu().numpy()
+ energy = expand(energy, duration)
+ else:
+ energy = predictions[3][index, :mel_len].detach().cpu().numpy()
+
+ with open(
+ os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
+ ) as f:
+ stats = json.load(f)
+ stats = stats["pitch"] + stats["energy"][:2]
+
+ # from datetime import datetime
+ # now = datetime.now()
+ # current_time = now.strftime("%D:%H:%M:%S")
+ # np.save(("mel_pred_%s.npy" % current_time).replace("/","-"), mel_prediction.cpu().numpy())
+ # np.save(("postnet_mel_prediction_%s.npy" % current_time).replace("/","-"), postnet_mel_prediction.cpu().numpy())
+ # np.save(("mel_target_%s.npy" % current_time).replace("/","-"), mel_target.cpu().numpy())
+
+ fig = plot_mel(
+ [
+ (mel_prediction.cpu().numpy(), pitch, energy),
+ (postnet_mel_prediction.cpu().numpy(), pitch, energy),
+ (mel_target.cpu().numpy(), pitch, energy),
+ ],
+ stats,
+ [
+ "Raw mel spectrogram prediction",
+ "Postnet mel prediction",
+ "Ground-Truth Spectrogram",
+ ],
+ )
+
+ if vocoder is not None:
+ from .model import vocoder_infer
+
+ wav_reconstruction = vocoder_infer(
+ mel_target.unsqueeze(0),
+ vocoder,
+ model_config,
+ preprocess_config,
+ )[0]
+ wav_prediction = vocoder_infer(
+ postnet_mel_prediction.unsqueeze(0),
+ vocoder,
+ model_config,
+ preprocess_config,
+ )[0]
+ else:
+ wav_reconstruction = wav_prediction = None
+
+ return fig, wav_reconstruction, wav_prediction, basename
+
+
+def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
+ if vocoder is not None:
+ from .model import vocoder_infer
+
+ wav_reconstruction = vocoder_infer(
+ mel_input.permute(0, 2, 1),
+ vocoder,
+ )
+ wav_prediction = vocoder_infer(
+ mel_prediction.permute(0, 2, 1),
+ vocoder,
+ )
+ else:
+ wav_reconstruction = wav_prediction = None
+
+ return wav_reconstruction, wav_prediction
+
+
+def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
+ # (diff_output, diff_loss, latent_loss) = diffusion
+
+ basenames = targets[0]
+
+ for i in range(len(predictions[1])):
+ basename = basenames[i]
+ src_len = predictions[8][i].item()
+ mel_len = predictions[9][i].item()
+ mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
+ # diff_output = diff_output[i, :mel_len].detach().transpose(0, 1)
+ # duration = predictions[5][i, :src_len].detach().cpu().numpy()
+ if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
+ pitch = predictions[2][i, :src_len].detach().cpu().numpy()
+ # pitch = expand(pitch, duration)
+ else:
+ pitch = predictions[2][i, :mel_len].detach().cpu().numpy()
+ if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
+ energy = predictions[3][i, :src_len].detach().cpu().numpy()
+ # energy = expand(energy, duration)
+ else:
+ energy = predictions[3][i, :mel_len].detach().cpu().numpy()
+ # import ipdb; ipdb.set_trace()
+ with open(
+ os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
+ ) as f:
+ stats = json.load(f)
+ stats = stats["pitch"] + stats["energy"][:2]
+
+ fig = plot_mel(
+ [
+ (mel_prediction.cpu().numpy(), pitch, energy),
+ ],
+ stats,
+ ["Synthetized Spectrogram by PostNet"],
+ )
+ # np.save("{}_postnet.npy".format(basename), mel_prediction.cpu().numpy())
+ plt.savefig(os.path.join(path, "{}_postnet_2.png".format(basename)))
+ plt.close()
+
+ from .model import vocoder_infer
+
+ mel_predictions = predictions[1].transpose(1, 2)
+ lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
+ wav_predictions = vocoder_infer(
+ mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
+ )
+
+ sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
+ for wav, basename in zip(wav_predictions, basenames):
+ wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)
+
+
+def plot_mel(data, titles=None):
+ fig, axes = plt.subplots(len(data), 1, squeeze=False)
+ if titles is None:
+ titles = [None for i in range(len(data))]
+
+ for i in range(len(data)):
+ mel = data[i]
+ axes[i][0].imshow(mel, origin="lower", aspect="auto")
+ axes[i][0].set_aspect(2.5, adjustable="box")
+ axes[i][0].set_ylim(0, mel.shape[0])
+ axes[i][0].set_title(titles[i], fontsize="medium")
+ axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
+ axes[i][0].set_anchor("W")
+
+ return fig
+
+
+def pad_1D(inputs, PAD=0):
+ def pad_data(x, length, PAD):
+ x_padded = np.pad(
+ x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
+ )
+ return x_padded
+
+ max_len = max((len(x) for x in inputs))
+ padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])
+
+ return padded
+
+
+def pad_2D(inputs, maxlen=None):
+ def pad(x, max_len):
+ PAD = 0
+ if np.shape(x)[0] > max_len:
+ raise ValueError("not max_len")
+
+ s = np.shape(x)[1]
+ x_padded = np.pad(
+ x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD
+ )
+ return x_padded[:, :s]
+
+ if maxlen:
+ output = np.stack([pad(x, maxlen) for x in inputs])
+ else:
+ max_len = max(np.shape(x)[0] for x in inputs)
+ output = np.stack([pad(x, max_len) for x in inputs])
+
+ return output
+
+
+def pad(input_ele, mel_max_length=None):
+ if mel_max_length:
+ max_len = mel_max_length
+ else:
+ max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
+
+ out_list = list()
+ for i, batch in enumerate(input_ele):
+ if len(batch.shape) == 1:
+ one_batch_padded = F.pad(
+ batch, (0, max_len - batch.size(0)), "constant", 0.0
+ )
+ elif len(batch.shape) == 2:
+ one_batch_padded = F.pad(
+ batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
+ )
+ out_list.append(one_batch_padded)
+ out_padded = torch.stack(out_list)
+ return out_padded
diff --git a/audioldm2/utils.py b/audioldm2/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..c098e25b99cf78b0c0befc71fb7ba7688e79c899
--- /dev/null
+++ b/audioldm2/utils.py
@@ -0,0 +1,352 @@
+import contextlib
+import importlib
+from huggingface_hub import hf_hub_download
+
+from inspect import isfunction
+import os
+import soundfile as sf
+import time
+import wave
+
+import progressbar
+
+CACHE_DIR = os.getenv(
+ "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2")
+)
+
+def read_list(fname):
+ result = []
+ with open(fname, "r", encoding="utf-8") as f:
+ for each in f.readlines():
+ each = each.strip('\n')
+ result.append(each)
+ return result
+
+def get_duration(fname):
+ with contextlib.closing(wave.open(fname, "r")) as f:
+ frames = f.getnframes()
+ rate = f.getframerate()
+ return frames / float(rate)
+
+
+def get_bit_depth(fname):
+ with contextlib.closing(wave.open(fname, "r")) as f:
+ bit_depth = f.getsampwidth() * 8
+ return bit_depth
+
+
+def get_time():
+ t = time.localtime()
+ return time.strftime("%d_%m_%Y_%H_%M_%S", t)
+
+
+def seed_everything(seed):
+ import random, os
+ import numpy as np
+ import torch
+
+ random.seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = True
+
+
+def save_wave(waveform, savepath, name="outwav"):
+ if type(name) is not list:
+ name = [name] * waveform.shape[0]
+
+ for i in range(waveform.shape[0]):
+ path = os.path.join(
+ savepath,
+ "%s_%s.wav"
+ % (
+ os.path.basename(name[i])
+ if (not ".wav" in name[i])
+ else os.path.basename(name[i]).split(".")[0],
+ i,
+ ),
+ )
+ print("Save audio to %s" % path)
+ sf.write(path, waveform[i, 0], samplerate=16000)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == "__is_first_stage__":
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ try:
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+ except:
+ import ipdb
+
+ ipdb.set_trace()
+
+
+def default_audioldm_config(model_name="audioldm2-full"):
+ basic_config = {
+ "metadata_root": "/mnt/bn/lqhaoheliu/metadata/processed/dataset_root.json",
+ "log_directory": "./log/audiomae_pred",
+ "precision": "high",
+ "data": {
+ "train": [
+ "audiocaps",
+ "audioset",
+ "wavcaps",
+ "audiostock_music_250k",
+ "free_to_use_sounds",
+ "epidemic_sound_effects",
+ "vggsound",
+ "million_song_dataset",
+ ],
+ "val": "audiocaps",
+ "test": "audiocaps",
+ "class_label_indices": "audioset",
+ "dataloader_add_ons": [
+ "extract_kaldi_fbank_feature",
+ "extract_vits_phoneme_and_flant5_text",
+ "waveform_rs_48k",
+ ],
+ },
+ "variables": {
+ "sampling_rate": 16000,
+ "mel_bins": 64,
+ "latent_embed_dim": 8,
+ "latent_t_size": 256,
+ "latent_f_size": 16,
+ "in_channels": 8,
+ "optimize_ddpm_parameter": True,
+ "warmup_steps": 5000,
+ },
+ "step": {
+ "validation_every_n_epochs": 1,
+ "save_checkpoint_every_n_steps": 5000,
+ "limit_val_batches": 10,
+ "max_steps": 1500000,
+ "save_top_k": 2,
+ },
+ "preprocessing": {
+ "audio": {
+ "sampling_rate": 16000,
+ "max_wav_value": 32768,
+ "duration": 10.24,
+ },
+ "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
+ "mel": {"n_mel_channels": 64, "mel_fmin": 0, "mel_fmax": 8000},
+ },
+ "augmentation": {"mixup": 0},
+ "model": {
+ "target": "audioldm2.latent_diffusion.models.ddpm.LatentDiffusion",
+ "params": {
+ "first_stage_config": {
+ "base_learning_rate": 0.000008,
+ "target": "audioldm2.latent_encoder.autoencoder.AutoencoderKL",
+ "params": {
+ "sampling_rate": 16000,
+ "batchsize": 4,
+ "monitor": "val/rec_loss",
+ "image_key": "fbank",
+ "subband": 1,
+ "embed_dim": 8,
+ "time_shuffle": 1,
+ "lossconfig": {
+ "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
+ "params": {
+ "disc_start": 50001,
+ "kl_weight": 1000,
+ "disc_weight": 0.5,
+ "disc_in_channels": 1,
+ },
+ },
+ "ddconfig": {
+ "double_z": True,
+ "mel_bins": 64,
+ "z_channels": 8,
+ "resolution": 256,
+ "downsample_time": False,
+ "in_channels": 1,
+ "out_ch": 1,
+ "ch": 128,
+ "ch_mult": [1, 2, 4],
+ "num_res_blocks": 2,
+ "attn_resolutions": [],
+ "dropout": 0,
+ },
+ },
+ },
+ "base_learning_rate": 0.0001,
+ "warmup_steps": 5000,
+ "optimize_ddpm_parameter": True,
+ "sampling_rate": 16000,
+ "batchsize": 16,
+ "linear_start": 0.0015,
+ "linear_end": 0.0195,
+ "num_timesteps_cond": 1,
+ "log_every_t": 200,
+ "timesteps": 1000,
+ "unconditional_prob_cfg": 0.1,
+ "parameterization": "eps",
+ "first_stage_key": "fbank",
+ "latent_t_size": 256,
+ "latent_f_size": 16,
+ "channels": 8,
+ "monitor": "val/loss_simple_ema",
+ "scale_by_std": True,
+ "unet_config": {
+ "target": "audioldm2.latent_diffusion.modules.diffusionmodules.openaimodel.UNetModel",
+ "params": {
+ "image_size": 64,
+ "context_dim": [768, 1024],
+ "in_channels": 8,
+ "out_channels": 8,
+ "model_channels": 128,
+ "attention_resolutions": [8, 4, 2],
+ "num_res_blocks": 2,
+ "channel_mult": [1, 2, 3, 5],
+ "num_head_channels": 32,
+ "use_spatial_transformer": True,
+ "transformer_depth": 1,
+ },
+ },
+ "evaluation_params": {
+ "unconditional_guidance_scale": 3.5,
+ "ddim_sampling_steps": 200,
+ "n_candidates_per_samples": 3,
+ },
+ "cond_stage_config": {
+ "crossattn_audiomae_generated": {
+ "cond_stage_key": "all",
+ "conditioning_key": "crossattn",
+ "target": "audioldm2.latent_diffusion.modules.encoders.modules.SequenceGenAudioMAECond",
+ "params": {
+ "always_output_audiomae_gt": False,
+ "learnable": True,
+ "device": "cuda",
+ "use_gt_mae_output": True,
+ "use_gt_mae_prob": 0.25,
+ "base_learning_rate": 0.0002,
+ "sequence_gen_length": 8,
+ "use_warmup": True,
+ "sequence_input_key": [
+ "film_clap_cond1",
+ "crossattn_flan_t5",
+ ],
+ "sequence_input_embed_dim": [512, 1024],
+ "batchsize": 16,
+ "cond_stage_config": {
+ "film_clap_cond1": {
+ "cond_stage_key": "text",
+ "conditioning_key": "film",
+ "target": "audioldm2.latent_diffusion.modules.encoders.modules.CLAPAudioEmbeddingClassifierFreev2",
+ "params": {
+ "sampling_rate": 48000,
+ "embed_mode": "text",
+ "amodel": "HTSAT-base",
+ },
+ },
+ "crossattn_flan_t5": {
+ "cond_stage_key": "text",
+ "conditioning_key": "crossattn",
+ "target": "audioldm2.latent_diffusion.modules.encoders.modules.FlanT5HiddenState",
+ },
+ "crossattn_audiomae_pooled": {
+ "cond_stage_key": "ta_kaldi_fbank",
+ "conditioning_key": "crossattn",
+ "target": "audioldm2.latent_diffusion.modules.encoders.modules.AudioMAEConditionCTPoolRand",
+ "params": {
+ "regularization": False,
+ "no_audiomae_mask": True,
+ "time_pooling_factors": [8],
+ "freq_pooling_factors": [8],
+ "eval_time_pooling": 8,
+ "eval_freq_pooling": 8,
+ "mask_ratio": 0,
+ },
+ },
+ },
+ },
+ },
+ "crossattn_flan_t5": {
+ "cond_stage_key": "text",
+ "conditioning_key": "crossattn",
+ "target": "audioldm2.latent_diffusion.modules.encoders.modules.FlanT5HiddenState",
+ },
+ },
+ },
+ },
+ }
+ return basic_config
+
+
+def get_metadata():
+ return {
+ "audioldm2-full": {
+ "path": os.path.join(
+ CACHE_DIR,
+ "audioldm2-full.pth",
+ ),
+ "url": "https://huggingface.co/haoheliu/audioldm2-full/resolve/main/audioldm2-full.pth",
+ },
+ }
+
+
+class MyProgressBar:
+ def __init__(self):
+ self.pbar = None
+
+ def __call__(self, block_num, block_size, total_size):
+ if not self.pbar:
+ self.pbar = progressbar.ProgressBar(maxval=total_size)
+ self.pbar.start()
+
+ downloaded = block_num * block_size
+ if downloaded < total_size:
+ self.pbar.update(downloaded)
+ else:
+ self.pbar.finish()
+
+
+def download_checkpoint(checkpoint_name="audioldm2-full"):
+ meta = get_metadata()
+ if checkpoint_name not in meta.keys():
+ print(
+ "The model name you provided is not supported. Please use one of the following: ",
+ meta.keys(),
+ )
+
+ model_id = "haoheliu/%s" % checkpoint_name
+ hf_hub_download(
+ repo_id=model_id,
+ filename=os.path.basename(meta[checkpoint_name]["path"]),
+ local_dir=os.path.dirname(meta[checkpoint_name]["path"]),
+ )
diff --git a/batch.lst b/batch.lst
new file mode 100644
index 0000000000000000000000000000000000000000..c52c7a52523775ad729bfbb350f9cd70ddfbf3e4
--- /dev/null
+++ b/batch.lst
@@ -0,0 +1,4 @@
+A forest of wind chimes singing a soothing melody in the breeze.
+A violin playing a heartfelt melody.
+A saxophone playing a soulful melody.
+Musical constellations twinkling in the night sky, forming a cosmic melody.
\ No newline at end of file
diff --git a/bg.png b/bg.png
new file mode 100644
index 0000000000000000000000000000000000000000..2811a3593a7492c5af5754ab6949a6e60f2635bd
Binary files /dev/null and b/bg.png differ
diff --git a/bin/audioldm2 b/bin/audioldm2
new file mode 100755
index 0000000000000000000000000000000000000000..2ff95674ad63326a4b0d7f7b633c2f87949502f3
--- /dev/null
+++ b/bin/audioldm2
@@ -0,0 +1,131 @@
+#!/usr/bin/python3
+import os
+import torch
+import logging
+from audioldm2 import text_to_audio, build_model, save_wave, get_time, read_list
+import argparse
+
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+matplotlib_logger = logging.getLogger('matplotlib')
+matplotlib_logger.setLevel(logging.WARNING)
+
+
+CACHE_DIR = os.getenv(
+ "AUDIOLDM_CACHE_DIR",
+ os.path.join(os.path.expanduser("~"), ".cache/audioldm2"))
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument(
+ "-t",
+ "--text",
+ type=str,
+ required=False,
+ default="",
+ help="Text prompt to the model for audio generation",
+)
+
+parser.add_argument(
+ "-tl",
+ "--text_list",
+ type=str,
+ required=False,
+ default="",
+ help="A file that contains text prompt to the model for audio generation",
+)
+
+parser.add_argument(
+ "-s",
+ "--save_path",
+ type=str,
+ required=False,
+ help="The path to save model output",
+ default="./output",
+)
+
+parser.add_argument(
+ "--model_name",
+ type=str,
+ required=False,
+ help="The checkpoint you gonna use",
+ default="audioldm2-full",
+ choices=["audioldm2-full"]
+)
+
+parser.add_argument(
+ "-b",
+ "--batchsize",
+ type=int,
+ required=False,
+ default=1,
+ help="Generate how many samples at the same time",
+)
+
+parser.add_argument(
+ "--ddim_steps",
+ type=int,
+ required=False,
+ default=200,
+ help="The sampling step for DDIM",
+)
+
+parser.add_argument(
+ "-gs",
+ "--guidance_scale",
+ type=float,
+ required=False,
+ default=3.5,
+ help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
+)
+
+parser.add_argument(
+ "-n",
+ "--n_candidate_gen_per_text",
+ type=int,
+ required=False,
+ default=3,
+ help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
+)
+
+parser.add_argument(
+ "--seed",
+ type=int,
+ required=False,
+ default=0,
+ help="Change this value (any integer number) will lead to a different generation result.",
+)
+
+args = parser.parse_args()
+
+torch.set_float32_matmul_precision("high")
+
+save_path = os.path.join(args.save_path, get_time())
+
+text = args.text
+random_seed = args.seed
+duration = 10
+guidance_scale = args.guidance_scale
+n_candidate_gen_per_text = args.n_candidate_gen_per_text
+
+os.makedirs(save_path, exist_ok=True)
+audioldm2 = build_model(model_name=args.model_name)
+
+if(args.text_list):
+ print("Generate audio based on the text prompts in %s" % args.text_list)
+ prompt_todo = read_list(args.text_list)
+else:
+ prompt_todo = [text]
+
+for text in prompt_todo:
+ waveform = text_to_audio(
+ audioldm2,
+ text,
+ seed=random_seed,
+ duration=duration,
+ guidance_scale=guidance_scale,
+ ddim_steps=args.ddim_steps,
+ n_candidate_gen_per_text=n_candidate_gen_per_text,
+ batchsize=args.batchsize,
+ )
+
+ save_wave(waveform, save_path, name=text)
diff --git a/bin/audioldm2.cmd b/bin/audioldm2.cmd
new file mode 100755
index 0000000000000000000000000000000000000000..c164fbfb6a194858b6d9019c8e29df3e57b3172a
--- /dev/null
+++ b/bin/audioldm2.cmd
@@ -0,0 +1,2 @@
+@echo OFF
+python -m audioldm2 %*
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ea2a8af90d37d8a1994cc87a7731b03ea065cd8f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+git+https://github.com/huggingface/diffusers.git
+git+https://github.com/huggingface/transformers.git
+--extra-index-url https://download.pytorch.org/whl/cu113
+torch >= 2.0
+huggingface_hub
+transformers==4.30.2
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..69f32a48bc29a8e293efbbf3fb05080e940de108
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# python3 setup.py sdist bdist_wheel
+"""
+@File : setup.py.py
+@Contact : haoheliu@gmail.com
+@License : (C)Copyright 2020-2100
+
+@Modify Time @Author @Version @Desciption
+------------ ------- -------- -----------
+9/6/21 5:16 PM Haohe Liu 1.0 None
+"""
+
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Note: To use the 'upload' functionality of this file, you must:
+# $ pipenv install twine --dev
+
+import io
+import os
+import sys
+from shutil import rmtree
+
+from setuptools import find_packages, setup, Command
+
+# Package meta-data.
+NAME = "audioldm2"
+DESCRIPTION = "This package is written for text-to-audio/music generation."
+URL = "https://github.com/haoheliu/audioldm2"
+EMAIL = "haoheliu@gmail.com"
+AUTHOR = "Haohe Liu"
+REQUIRES_PYTHON = ">=3.7.0"
+VERSION = "0.0.2"
+
+# What packages are required for this module to be executed?
+REQUIRED = [
+ "torch>=1.13.0",
+ "torchaudio>=0.13.0",
+ "torchvision>=0.14.0",
+ "tqdm",
+ "gradio",
+ "pyyaml",
+ "einops",
+ "chardet",
+ "numpy<=1.23.5",
+ "soundfile",
+ "librosa==0.9.2",
+ "scipy",
+ "pandas",
+ "torchlibrosa==0.0.9",
+ "transformers",
+ "progressbar",
+ "ftfy",
+]
+
+# What packages are optional?
+EXTRAS = {}
+
+# The rest you shouldn't have to touch too much :)
+# ------------------------------------------------
+# Except, perhaps the License and Trove Classifiers!
+# If you do change the License, remember to change the Trove Classifier for that!
+
+here = os.path.abspath(os.path.dirname(__file__))
+
+# Import the README and use it as the long-description.
+# Note: this will only work if 'README.md' is present in your MANIFEST.in file!
+try:
+ with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f:
+ long_description = "\n" + f.read()
+except FileNotFoundError:
+ long_description = DESCRIPTION
+
+# Load the package's __version__.py module as a dictionary.
+about = {}
+if not VERSION:
+ project_slug = NAME.lower().replace("-", "_").replace(" ", "_")
+ with open(os.path.join(here, project_slug, "__version__.py")) as f:
+ exec(f.read(), about)
+else:
+ about["__version__"] = VERSION
+
+
+class UploadCommand(Command):
+ """Support setup.py upload."""
+
+ description = "Build and publish the package."
+ user_options = []
+
+ @staticmethod
+ def status(s):
+ """Prints things in bold."""
+ print("\033[1m{0}\033[0m".format(s))
+
+ def initialize_options(self):
+ pass
+
+ def finalize_options(self):
+ pass
+
+ def run(self):
+ try:
+ self.status("Removing previous builds…")
+ rmtree(os.path.join(here, "dist"))
+ except OSError:
+ pass
+
+ self.status("Building Source and Wheel (universal) distribution…")
+ os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable))
+
+ self.status("Uploading the package to PyPI via Twine…")
+ os.system("twine upload dist/*")
+
+ self.status("Pushing git tags…")
+ os.system("git tag v{0}".format(about["__version__"]))
+ os.system("git push --tags")
+
+ sys.exit()
+
+
+# Where the magic happens:
+setup(
+ name=NAME,
+ version=about["__version__"],
+ description=DESCRIPTION,
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ author=AUTHOR,
+ author_email=EMAIL,
+ python_requires=REQUIRES_PYTHON,
+ url=URL,
+ # packages=find_packages(exclude=[]),
+ # If your package is a single module, use this instead of 'packages':
+ # entry_points={
+ # 'console_scripts': ['mycli=mymodule:cli'],
+ # },
+ install_requires=REQUIRED,
+ extras_require=EXTRAS,
+ packages=find_packages(),
+ include_package_data=True,
+ license="MIT",
+ classifiers=[
+ # Trove classifiers
+ # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
+ "License :: OSI Approved :: MIT License",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: Implementation :: CPython",
+ "Programming Language :: Python :: Implementation :: PyPy",
+ ],
+ # $ setup.py publish support.
+ cmdclass={
+ "upload": UploadCommand,
+ },
+ scripts=["bin/audioldm2.cmd", "bin/audioldm2"],
+)
diff --git a/share_btn.py b/share_btn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0378607680fa5468e9034d230f546f5f0913ae0
--- /dev/null
+++ b/share_btn.py
@@ -0,0 +1,74 @@
+community_icon_html = """
+
+
+ """
+
+loading_icon_html = """ """
+
+share_js = """async () => {
+ async function uploadFile(file){
+ const UPLOAD_URL = 'https://huggingface.co/uploads';
+ const response = await fetch(UPLOAD_URL, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': file.type,
+ 'X-Requested-With': 'XMLHttpRequest',
+ },
+ body: file, /// <- File inherits from Blob
+ });
+ const url = await response.text();
+ return url;
+ }
+ async function getInputVideoFile(videoEl){
+ const res = await fetch(videoEl.src);
+ const blob = await res.blob();
+ const videoId = Date.now() % 200;
+ const fileName = `sd-perception-${{videoId}}.mp4`;
+ return new File([blob], fileName, { type: 'video/mp4' });
+ }
+
+ async function audioToBase64(audioFile) {
+ return new Promise((resolve, reject) => {
+ let reader = new FileReader();
+ reader.readAsDataURL(audioFile);
+ reader.onload = () => resolve(reader.result);
+ reader.onerror = error => reject(error);
+
+ });
+ }
+ const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
+ const inputPromptEl = gradioEl.querySelector('#prompt-in input').value;
+ const outputVideoEl = gradioEl.querySelector('#output-video video');
+
+ let titleTxt = `Text-to-Audio: ${inputPromptEl}`;
+
+ const shareBtnEl = gradioEl.querySelector('#share-btn');
+ const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
+ const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
+ if(!outputVideoEl){
+ return;
+ };
+ shareBtnEl.style.pointerEvents = 'none';
+ shareIconEl.style.display = 'none';
+ loadingIconEl.style.removeProperty('display');
+ const outputVideo = await getInputVideoFile(outputVideoEl);
+ const urlOutputVideo = await uploadFile(outputVideo);
+
+ const descriptionMd = `
+##### ${inputPromptEl}
+
+${urlOutputVideo}
+`;
+ const params = new URLSearchParams({
+ title: titleTxt,
+ description: descriptionMd,
+ });
+ const paramsStr = params.toString();
+ window.open(`https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/discussions/new?${paramsStr}`, '_blank');
+ shareBtnEl.style.removeProperty('pointer-events');
+ shareIconEl.style.removeProperty('display');
+ loadingIconEl.style.display = 'none';
+}"""
diff --git a/tests/code_coverage.py b/tests/code_coverage.py
new file mode 100644
index 0000000000000000000000000000000000000000..deb035e9fedffacd8bf3a9c37d5566fa8fd4e819
--- /dev/null
+++ b/tests/code_coverage.py
@@ -0,0 +1,3 @@
+import os
+
+os.system('python3 bin/audioldm2 -t "A toilet flushing and water trickling"')
diff --git a/tests/code_coverage.sh b/tests/code_coverage.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0a5920c645c262c80436ff586e7ad4825e9e5622
--- /dev/null
+++ b/tests/code_coverage.sh
@@ -0,0 +1 @@
+pytest --cov=src tests/*
\ No newline at end of file